diff --git a/.github/workflows/changelog-release-update.yml b/.github/workflows/changelog-release-update.yml new file mode 100644 index 0000000..79b85ad --- /dev/null +++ b/.github/workflows/changelog-release-update.yml @@ -0,0 +1,34 @@ +# .github/workflows/update-changelog.yaml +name: "Update Changelog" + +on: + release: + types: [released] + +permissions: + pull-requests: write + contents: write + +jobs: + update: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ github.event.release.target_commitish }} + + - name: Update Changelog + uses: stefanzweifel/changelog-updater-action@v1 + with: + latest-version: ${{ github.event.release.tag_name }} + heading-text: ${{ github.event.release.name }} + + - name: Create Pull Request + uses: peter-evans/create-pull-request@v6 + with: + branch: docs/changelog-update-${{ github.event.release.tag_name }} + title: '[Changelog] Update to ${{ github.event.release.tag_name }}' + add-paths: | + CHANGELOG.md diff --git a/.gitignore b/.gitignore index d8baf06..d92403b 100644 --- a/.gitignore +++ b/.gitignore @@ -121,6 +121,7 @@ celerybeat.pid # Environments .env +.envrc .venv env/ venv/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4b6367..c042b1f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: - id: check-added-large-files # Check for large files added to git - id: check-merge-conflict # Check for files that contain merge conflict - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.4.2 + rev: 24.8.0 hooks: - id: black args: [--line-length=120] @@ -34,7 +34,7 @@ repos: - --force-single-line-imports - --profile black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.6 + rev: v0.6.3 hooks: - id: ruff # Next line if for documenation cod snippets @@ -65,6 +65,6 @@ repos: - id: optional-dependencies-all args: ["--inplace", "--exclude-keys=dev,docs,tests", "--group=dev=all,docs,tests"] - repo: https://github.com/tox-dev/pyproject-fmt - rev: "2.1.3" + rev: "2.2.1" hooks: - id: pyproject-fmt diff --git a/CHANGELOG.md b/CHANGELOG.md index c0b75f0..55fba06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,19 +8,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! -## [Unreleased] +## [Unreleased](https://github.com/ecmwf/anemoi-models/compare/0.3.0...HEAD) + +## [0.3.0](https://github.com/ecmwf/anemoi-models/compare/0.2.1...0.3.0) - Remapping of (meteorological) Variables ### Added +- CI workflow to update the changelog on release +- Remapper: preprocessor for remapping one variable to multiple ones. Includes changes to the data indices, since the remapper changes the number of variables; configured via optional config keywords. + ### Changed - - Update CI to inherit from common infrastructue reusable workflows - - run downstream-ci only when src and tests folders have changed - - New error messages for wrongs graphs. +- Update CI to inherit from common infrastructure reusable workflows +- Run downstream-ci only when the src and tests folders have changed +- New error messages for wrong graphs. ### Removed -## [0.2.1] - Dependency update +## [0.2.1](https://github.com/ecmwf/anemoi-models/compare/0.2.0...0.2.1) - Dependency update ### Added @@ -31,7 +36,7 @@ Keep it human-readable, your future self will thank you!
- anemoi-datasets dependency -## [0.2.0] - Support Heterodata +## [0.2.0](https://github.com/ecmwf/anemoi-models/compare/0.1.0...0.2.0) - Support Heterodata ### Added @@ -41,15 +46,12 @@ Keep it human-readable, your future self will thank you! - Updated to support new PyTorch Geometric HeteroData structure (defined by `anemoi-graphs` package). -## [0.1.0] - Initial Release +## [0.1.0](https://github.com/ecmwf/anemoi-models/releases/tag/0.1.0) - Initial Release ### Added + - Documentation - Initial code release with models, layers, distributed, preprocessing, and data_indices - Added Changelog -[unreleased]: https://github.com/ecmwf/anemoi-models/compare/0.2.1...HEAD -[0.2.1]: https://github.com/ecmwf/anemoi-models/compare/0.2.0...0.2.1 -[0.2.0]: https://github.com/ecmwf/anemoi-models/compare/0.1.0...0.2.0 -[0.1.0]: https://github.com/ecmwf/anemoi-models/releases/tag/0.1.0 diff --git a/docs/modules/data_indices.rst b/docs/modules/data_indices.rst index 8fbff4d..c546795 100644 --- a/docs/modules/data_indices.rst +++ b/docs/modules/data_indices.rst @@ -45,12 +45,33 @@ config entry: :alt: Schematic of IndexCollection with Data Indexing on Data and Model levels. :align: center -The are two Index-levels: +Additionally, prognostic and forcing variables can be remapped and +converted to multiple variables. The conversion is then done by the +remapper preprocessor. + +.. code:: yaml + + data: + remapped: + d: + - "d_1" + - "d_2" + +There are two main Index-levels: - Data: The data at "Zarr"-level provided by Anemoi-Datasets - Model: The "squeezed" tensors with irrelevant parts missing. -These are both split into two versions: +Additionally, there are two internal model levels (after the preprocessor +and before the postprocessor) that are necessary because of the possibility +to remap variables to multiple variables. + +- Internal Data: Variables from Data-level that are used internally in + the model, but not exposed to the user. +- Internal Model: Variables from Model-level that are used internally + in the model, but not exposed to the user. + +All indices at the different levels are split into two versions: - Input: The data going into training / model - Output: The data produced by training / model diff --git a/docs/modules/preprocessing.rst b/docs/modules/preprocessing.rst index e68898f..e1eab7f 100644 --- a/docs/modules/preprocessing.rst +++ b/docs/modules/preprocessing.rst @@ -33,3 +33,16 @@ following classes: :members: :no-undoc-members: :show-inheritance: + +********** + Remapper +********** + +The remapper module is used to remap one variable to multiple other +variables, as listed under ``data.remapped`` in the data config. The +module contains the following classes: + +..
automodule:: anemoi.models.preprocessing.remapper + :members: + :no-undoc-members: + :show-inheritance: diff --git a/src/anemoi/models/data_indices/collection.py b/src/anemoi/models/data_indices/collection.py index 3e325aa..266c11a 100644 --- a/src/anemoi/models/data_indices/collection.py +++ b/src/anemoi/models/data_indices/collection.py @@ -25,26 +25,76 @@ class IndexCollection: def __init__(self, config, name_to_index) -> None: self.config = OmegaConf.to_container(config, resolve=True) - + self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1))) self.forcing = [] if config.data.forcing is None else OmegaConf.to_container(config.data.forcing, resolve=True) self.diagnostic = ( [] if config.data.diagnostic is None else OmegaConf.to_container(config.data.diagnostic, resolve=True) ) + # config.data.remapped is an optional dictionary with every remapper as one entry + self.remapped = ( + dict() + if config.data.get("remapped") is None + else OmegaConf.to_container(config.data.remapped, resolve=True) + ) + self.forcing_remapped = self.forcing.copy() assert set(self.diagnostic).isdisjoint(self.forcing), ( f"Diagnostic and forcing variables overlap: {set(self.diagnostic).intersection(self.forcing)}. ", "Please drop them at a dataset-level to exclude them from the training data.", ) - self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1))) + assert set(self.remapped).isdisjoint(self.diagnostic), ( + "Remapped variables overlap with diagnostic variables. Not implemented.", + ) + assert set(self.remapped).issubset(self.name_to_index), ( + "Remapping a variable that does not exist in the dataset. Check for typos: ", + f"{set(self.remapped).difference(self.name_to_index)}", + ) name_to_index_model_input = { name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.diagnostic) } name_to_index_model_output = { name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.forcing) } + # remove remapped variables from internal data and model indices + name_to_index_internal_data_input = { + name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.remapped) + } + name_to_index_internal_model_input = { + name: i for i, name in enumerate(key for key in name_to_index_model_input if key not in self.remapped) + } + name_to_index_internal_model_output = { + name: i for i, name in enumerate(key for key in name_to_index_model_output if key not in self.remapped) + } + # for all variables to be remapped we add the resulting remapped variables to the end of the tensors + # keep track of that in the index collections + for key in self.remapped: + for mapped in self.remapped[key]: + # add index of remapped variables to dictionary + name_to_index_internal_model_input[mapped] = len(name_to_index_internal_model_input) + name_to_index_internal_data_input[mapped] = len(name_to_index_internal_data_input) + if key not in self.forcing: + # do not include forcing variables in the remapped model output + name_to_index_internal_model_output[mapped] = len(name_to_index_internal_model_output) + else: + # add remapped forcing variables to forcing_remapped + self.forcing_remapped += [mapped] + if key in self.forcing: + # if key is in forcing we need to remove it from forcing_remapped after remapped variables have been added + self.forcing_remapped.remove(key) self.data = DataIndex(self.diagnostic, self.forcing, self.name_to_index) + self.internal_data = DataIndex( + self.diagnostic, +
self.forcing_remapped, + name_to_index_internal_data_input, + ) # internal after the remapping applied to data (training) self.model = ModelIndex(self.diagnostic, self.forcing, name_to_index_model_input, name_to_index_model_output) + self.internal_model = ModelIndex( + self.diagnostic, + self.forcing_remapped, + name_to_index_internal_model_input, + name_to_index_internal_model_output, + ) # internal after the remapping applied to model (inference) def __repr__(self) -> str: return f"IndexCollection(config={self.config}, name_to_index={self.name_to_index})" @@ -54,7 +104,12 @@ def __eq__(self, other): # don't attempt to compare against unrelated types return NotImplemented - return self.model == other.model and self.data == other.data + return ( + self.model == other.model + and self.data == other.data + and self.internal_model == other.internal_model + and self.internal_data == other.internal_data + ) def __getitem__(self, key): return getattr(self, key) @@ -63,6 +118,8 @@ def todict(self): return { "data": self.data.todict(), "model": self.model.todict(), + "internal_model": self.internal_model.todict(), + "internal_data": self.internal_data.todict(), } @staticmethod diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py index 54c548d..626940f 100644 --- a/src/anemoi/models/interface/__init__.py +++ b/src/anemoi/models/interface/__init__.py @@ -65,7 +65,7 @@ def _build_model(self) -> None: """Builds the model and pre- and post-processors.""" # Instantiate processors processors = [ - [name, instantiate(processor, statistics=self.statistics, data_indices=self.data_indices)] + [name, instantiate(processor, data_indices=self.data_indices, statistics=self.statistics)] for name, processor in self.config.data.processors.items() ] diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 0f37474..3414dc5 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -104,22 +104,23 @@ def __init__( ) def _calculate_shapes_and_indices(self, data_indices: dict) -> None: - self.num_input_channels = len(data_indices.model.input) - self.num_output_channels = len(data_indices.model.output) - self._internal_input_idx = data_indices.model.input.prognostic - self._internal_output_idx = data_indices.model.output.prognostic + self.num_input_channels = len(data_indices.internal_model.input) + self.num_output_channels = len(data_indices.internal_model.output) + self._internal_input_idx = data_indices.internal_model.input.prognostic + self._internal_output_idx = data_indices.internal_model.output.prognostic def _assert_matching_indices(self, data_indices: dict) -> None: - assert len(self._internal_output_idx) == len(data_indices.model.output.full) - len( - data_indices.model.output.diagnostic + assert len(self._internal_output_idx) == len(data_indices.internal_model.output.full) - len( + data_indices.internal_model.output.diagnostic ), ( - f"Mismatch between the internal data indices ({len(self._internal_output_idx)}) and the output indices excluding " - f"diagnostic variables ({len(data_indices.model.output.full) - len(data_indices.model.output.diagnostic)})", + f"Mismatch between the internal data indices ({len(self._internal_output_idx)}) and " + f"the internal output indices excluding diagnostic variables " + f"({len(data_indices.internal_model.output.full) - len(data_indices.internal_model.output.diagnostic)})", ) assert 
len(self._internal_input_idx) == len( self._internal_output_idx, - ), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}" + ), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}" def _define_tensor_sizes(self, config: DotDict) -> None: self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes diff --git a/src/anemoi/models/preprocessing/__init__.py b/src/anemoi/models/preprocessing/__init__.py index 081afaf..53017fb 100644 --- a/src/anemoi/models/preprocessing/__init__.py +++ b/src/anemoi/models/preprocessing/__init__.py @@ -14,6 +14,8 @@ from torch import Tensor from torch import nn +from anemoi.models.data_indices.collection import IndexCollection + LOGGER = logging.getLogger(__name__) @@ -23,19 +25,19 @@ class BasePreprocessor(nn.Module): def __init__( self, config=None, + data_indices: Optional[IndexCollection] = None, statistics: Optional[dict] = None, - data_indices: Optional[dict] = None, ) -> None: """Initialize the preprocessor. Parameters ---------- config : DotDict - configuration object + configuration object of the processor + data_indices : IndexCollection + Data indices for input and output variables statistics : dict Data statistics dictionary - data_indices : dict - Data indices for input and output variables """ super().__init__() diff --git a/src/anemoi/models/preprocessing/imputer.py b/src/anemoi/models/preprocessing/imputer.py index a7b0a8a..6ef5adb 100644 --- a/src/anemoi/models/preprocessing/imputer.py +++ b/src/anemoi/models/preprocessing/imputer.py @@ -33,16 +33,15 @@ def __init__( Parameters ---------- config : DotDict - configuration object + configuration object of the processor + data_indices : IndexCollection + Data indices for input and output variables statistics : dict Data statistics dictionary - data_indices : dict - Data indices for input and output variables """ - super().__init__(config, statistics, data_indices) + super().__init__(config, data_indices, statistics) self.nan_locations = None - self.data_indices = data_indices def _validate_indices(self): assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.replacement), ( @@ -174,8 +173,8 @@ class InputImputer(BaseImputer): def __init__( self, config=None, + data_indices: Optional[IndexCollection] = None, statistics: Optional[dict] = None, - data_indices: Optional[dict] = None, ) -> None: super().__init__(config, data_indices, statistics) @@ -201,7 +200,10 @@ class ConstantImputer(BaseImputer): """ def __init__( - self, config=None, statistics: Optional[dict] = None, data_indices: Optional[IndexCollection] = None + self, + config=None, + data_indices: Optional[IndexCollection] = None, + statistics: Optional[dict] = None, ) -> None: super().__init__(config, data_indices, statistics) diff --git a/src/anemoi/models/preprocessing/normalizer.py b/src/anemoi/models/preprocessing/normalizer.py index 5bb97ee..bc75466 100644 --- a/src/anemoi/models/preprocessing/normalizer.py +++ b/src/anemoi/models/preprocessing/normalizer.py @@ -34,13 +34,13 @@ def __init__( Parameters ---------- config : DotDict - configuration object + configuration object of the processor + data_indices : IndexCollection + Data indices for input and output variables statistics : dict Data statistics dictionary - data_indices : dict - Data indices for input and output variables """ - super().__init__(config, statistics, data_indices) + super().__init__(config, data_indices, statistics) name_to_index_training_input = 
self.data_indices.data.input.name_to_index diff --git a/src/anemoi/models/preprocessing/remapper.py b/src/anemoi/models/preprocessing/remapper.py new file mode 100644 index 0000000..a79e2af --- /dev/null +++ b/src/anemoi/models/preprocessing/remapper.py @@ -0,0 +1,300 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +from abc import ABC +from typing import Optional + +import torch + +from anemoi.models.data_indices.collection import IndexCollection +from anemoi.models.preprocessing import BasePreprocessor + +LOGGER = logging.getLogger(__name__) + + +def cos_converter(x): + """Convert an angle in degrees to its cosine.""" + return torch.cos(x / 180 * torch.pi) + + +def sin_converter(x): + """Convert an angle in degrees to its sine.""" + return torch.sin(x / 180 * torch.pi) + + +def atan2_converter(x): + """Convert cosine and sine back to an angle in degrees. + + Input: + x[..., 0]: cos + x[..., 1]: sin + """ + return torch.remainder(torch.atan2(x[..., 1], x[..., 0]) * 180 / torch.pi, 360) + + +class BaseRemapperVariable(BasePreprocessor, ABC): + """Base class for Remapping Variables.""" + + def __init__( + self, + config=None, + data_indices: Optional[IndexCollection] = None, + statistics: Optional[dict] = None, + ) -> None: + """Initialize the remapper. + + Parameters + ---------- + config : DotDict + configuration object of the processor + data_indices : IndexCollection + Data indices for input and output variables + statistics : dict + Data statistics dictionary + """ + super().__init__(config, data_indices, statistics) + + def _validate_indices(self): + assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.remappers), ( + f"Error creating conversion indices {len(self.index_training_input)}, " + f"{len(self.index_inference_input)}, {len(self.remappers)}" + ) + assert len(self.index_training_output) == len(self.index_inference_output) <= len(self.remappers), ( + f"Error creating conversion indices {len(self.index_training_output)}, " + f"{len(self.index_inference_output)}, {len(self.remappers)}" + ) + assert len(set(self.index_training_input + self.indices_keep_training_input)) == self.num_training_input_vars, ( + "Error creating conversion indices: variables remapped in config.data.remapped " + "that have no remapping function defined. Preprocessed tensors would contain empty columns."
+ ) + + def _create_remapping_indices( + self, + statistics=None, + ): + """Create the parameter indices for remapping.""" + # list for training and inference mode as position of parameters can change + name_to_index_training_input = self.data_indices.data.input.name_to_index + name_to_index_inference_input = self.data_indices.model.input.name_to_index + name_to_index_training_remapped_input = self.data_indices.internal_data.input.name_to_index + name_to_index_inference_remapped_input = self.data_indices.internal_model.input.name_to_index + name_to_index_training_remapped_output = self.data_indices.internal_data.output.name_to_index + name_to_index_inference_remapped_output = self.data_indices.internal_model.output.name_to_index + name_to_index_training_output = self.data_indices.data.output.name_to_index + name_to_index_inference_output = self.data_indices.model.output.name_to_index + + self.num_training_input_vars = len(name_to_index_training_input) + self.num_inference_input_vars = len(name_to_index_inference_input) + self.num_remapped_training_input_vars = len(name_to_index_training_remapped_input) + self.num_remapped_inference_input_vars = len(name_to_index_inference_remapped_input) + self.num_remapped_training_output_vars = len(name_to_index_training_remapped_output) + self.num_remapped_inference_output_vars = len(name_to_index_inference_remapped_output) + self.num_training_output_vars = len(name_to_index_training_output) + self.num_inference_output_vars = len(name_to_index_inference_output) + self.indices_keep_training_input = [] + for key, item in self.data_indices.data.input.name_to_index.items(): + if key in self.data_indices.internal_data.input.name_to_index: + self.indices_keep_training_input.append(item) + self.indices_keep_inference_input = [] + for key, item in self.data_indices.model.input.name_to_index.items(): + if key in self.data_indices.internal_model.input.name_to_index: + self.indices_keep_inference_input.append(item) + self.indices_keep_training_output = [] + for key, item in self.data_indices.data.output.name_to_index.items(): + if key in self.data_indices.internal_data.output.name_to_index: + self.indices_keep_training_output.append(item) + self.indices_keep_inference_output = [] + for key, item in self.data_indices.model.output.name_to_index.items(): + if key in self.data_indices.internal_model.output.name_to_index: + self.indices_keep_inference_output.append(item) + + ( + self.index_training_input, + self.index_training_remapped_input, + self.index_inference_input, + self.index_inference_remapped_input, + self.index_training_output, + self.index_training_backmapped_output, + self.index_inference_output, + self.index_inference_backmapped_output, + self.remappers, + self.backmappers, + ) = ([], [], [], [], [], [], [], [], [], []) + + # Create parameter indices for remapping variables + for name in name_to_index_training_input: + + method = self.methods.get(name, self.default) + + if method == "none": + continue + + if method == "cos_sin": + self.index_training_input.append(name_to_index_training_input[name]) + self.index_training_output.append(name_to_index_training_output[name]) + self.index_inference_input.append(name_to_index_inference_input[name]) + if name in name_to_index_inference_output: + self.index_inference_output.append(name_to_index_inference_output[name]) + else: + # this is a forcing variable. It is not in the inference output. 
+ self.index_inference_output.append(None) + multiple_training_output, multiple_inference_output = [], [] + multiple_training_input, multiple_inference_input = [], [] + for name_dst in self.method_config[method][name]: + assert name_dst in self.data_indices.internal_data.input.name_to_index, ( + f"Trying to remap {name} to {name_dst}, but {name_dst} is not a variable. " + f"Add {name_dst} to the remapping of {name} in config.data.remapped. " + ) + multiple_training_input.append(name_to_index_training_remapped_input[name_dst]) + multiple_training_output.append(name_to_index_training_remapped_output[name_dst]) + multiple_inference_input.append(name_to_index_inference_remapped_input[name_dst]) + if name_dst in name_to_index_inference_remapped_output: + multiple_inference_output.append(name_to_index_inference_remapped_output[name_dst]) + else: + # this is a forcing variable. It is not in the inference output. + multiple_inference_output.append(None) + + self.index_training_remapped_input.append(multiple_training_input) + self.index_inference_remapped_input.append(multiple_inference_input) + self.index_training_backmapped_output.append(multiple_training_output) + self.index_inference_backmapped_output.append(multiple_inference_output) + + self.remappers.append([cos_converter, sin_converter]) + self.backmappers.append(atan2_converter) + + LOGGER.info(f"Map {name} to cosine and sine and save result in {self.method_config[method][name]}.") + + else: + raise ValueError(f"Unknown remapping method for {name}: {method}") + + def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + """Remap and convert the input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor + in_place : bool + Whether to process the tensor in place. + in_place is not possible for this preprocessor. + """ + # Choose correct index based on number of variables + if x.shape[-1] == self.num_training_input_vars: + index = self.index_training_input + indices_remapped = self.index_training_remapped_input + indices_keep = self.indices_keep_training_input + target_number_columns = self.num_remapped_training_input_vars + + elif x.shape[-1] == self.num_inference_input_vars: + index = self.index_inference_input + indices_remapped = self.index_inference_remapped_input + indices_keep = self.indices_keep_inference_input + target_number_columns = self.num_remapped_inference_input_vars + + else: + raise ValueError( + f"Input tensor ({x.shape[-1]}) does not match the training " + f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})", + ) + + # create new tensor with target number of columns + x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device) + if in_place and not self.printed_preprocessor_warning: + LOGGER.warning( + "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.", + ) + self.printed_preprocessor_warning = True + + # copy variables that are not remapped + x_remapped[..., : len(indices_keep)] = x[..., indices_keep] + + # Remap variables + for idx_dst, remapper, idx_src in zip(indices_remapped, self.remappers, index): + if idx_src is not None: + for jj, ii in enumerate(idx_dst): + x_remapped[..., ii] = remapper[jj](x[..., idx_src]) + + return x_remapped + + def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + """Convert and remap the output tensor.
+ + Parameters + ---------- + x : torch.Tensor + Input tensor + in_place : bool + Whether to process the tensor in place. + in_place is not possible for this postprocessor. + """ + # Choose correct index based on number of variables + if x.shape[-1] == self.num_remapped_training_output_vars: + index = self.index_training_output + indices_remapped = self.index_training_backmapped_output + indices_keep = self.indices_keep_training_output + target_number_columns = self.num_training_output_vars + + elif x.shape[-1] == self.num_remapped_inference_output_vars: + index = self.index_inference_output + indices_remapped = self.index_inference_backmapped_output + indices_keep = self.indices_keep_inference_output + target_number_columns = self.num_inference_output_vars + + else: + raise ValueError( + f"Input tensor ({x.shape[-1]}) does not match the training " + f"({self.num_remapped_training_output_vars}) or inference shape ({self.num_remapped_inference_output_vars})", + ) + + # create new tensor with target number of columns + x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device) + if in_place and not self.printed_postprocessor_warning: + LOGGER.warning( + "Remapper (postprocessor) called with in_place=True. This postprocessor cannot be applied in_place as the number of columns in the tensors changes.", + ) + self.printed_postprocessor_warning = True + + # copy variables that are not remapped + x_remapped[..., indices_keep] = x[..., : len(indices_keep)] + + # Backmap variables + for idx_dst, backmapper, idx_src in zip(index, self.backmappers, indices_remapped): + if idx_dst is not None: + x_remapped[..., idx_dst] = backmapper(x[..., idx_src]) + + return x_remapped + + +class Remapper(BaseRemapperVariable): + """Remap and convert variables. + + cos_sin: + Remap the variable to cosine and sine. + Map output back to degrees.
+ + Example: + cos_sin: + "mwd" : ["cos_mwd", "sin_mwd"] + + """ + + def __init__( + self, + config=None, + data_indices: Optional[IndexCollection] = None, + statistics: Optional[dict] = None, + ) -> None: + super().__init__(config, data_indices, statistics) + + self.printed_preprocessor_warning, self.printed_postprocessor_warning = False, False + + self._create_remapping_indices(statistics) + + self._validate_indices() diff --git a/tests/data_indices/test_collection.py b/tests/data_indices/test_collection.py index 5558c91..8505f8a 100644 --- a/tests/data_indices/test_collection.py +++ b/tests/data_indices/test_collection.py @@ -17,50 +17,100 @@ def data_indices(): config = DictConfig( { "data": { - "forcing": ["x"], + "forcing": ["x", "e"], "diagnostic": ["z", "q"], + "remapped": { + "e": ["e_1", "e_2"], + "d": ["d_1", "d_2"], + }, }, }, ) - name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "e": 4, "d": 5, "other": 6} return IndexCollection(config=config, name_to_index=name_to_index) def test_dataindices_init(data_indices) -> None: - assert data_indices.data.input.includes == ["x"] + assert data_indices.data.input.includes == ["x", "e"] assert data_indices.data.input.excludes == ["z", "q"] + assert data_indices.internal_data.input.includes == ["x", "e_1", "e_2"] + assert data_indices.internal_data.input.excludes == ["z", "q"] + assert data_indices.internal_data.output.includes == ["z", "q"] + assert data_indices.internal_data.output.excludes == ["x", "e_1", "e_2"] assert data_indices.data.output.includes == ["z", "q"] - assert data_indices.data.output.excludes == ["x"] - assert data_indices.model.input.includes == ["x"] + assert data_indices.data.output.excludes == ["x", "e"] + assert data_indices.model.input.includes == ["x", "e"] assert data_indices.model.input.excludes == [] + assert data_indices.internal_model.input.includes == ["x", "e_1", "e_2"] + assert data_indices.internal_model.input.excludes == [] + assert data_indices.internal_model.output.includes == ["z", "q"] + assert data_indices.internal_model.output.excludes == [] assert data_indices.model.output.includes == ["z", "q"] assert data_indices.model.output.excludes == [] - assert data_indices.data.input.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} - assert data_indices.data.output.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} - assert data_indices.model.input.name_to_index == {"x": 0, "y": 1, "other": 2} - assert data_indices.model.output.name_to_index == {"y": 0, "z": 1, "q": 2, "other": 3} + assert data_indices.data.input.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "e": 4, "d": 5, "other": 6} + assert data_indices.internal_data.input.name_to_index == { + "x": 0, + "y": 1, + "z": 2, + "q": 3, + "other": 4, + "e_1": 5, + "e_2": 6, + "d_1": 7, + "d_2": 8, + } + assert data_indices.internal_data.output.name_to_index == { + "x": 0, + "y": 1, + "z": 2, + "q": 3, + "other": 4, + "e_1": 5, + "e_2": 6, + "d_1": 7, + "d_2": 8, + } + assert data_indices.data.output.name_to_index == {"x": 0, "y": 1, "z": 2, "q": 3, "e": 4, "d": 5, "other": 6} + assert data_indices.model.input.name_to_index == {"x": 0, "y": 1, "e": 2, "d": 3, "other": 4} + assert data_indices.internal_model.input.name_to_index == { + "x": 0, + "y": 1, + "other": 2, + "e_1": 3, + "e_2": 4, + "d_1": 5, + "d_2": 6, + } + assert data_indices.internal_model.output.name_to_index == {"y": 0, "z": 1, "q": 2, "other": 3, "d_1": 4, "d_2": 5} + assert
data_indices.model.output.name_to_index == {"y": 0, "z": 1, "q": 2, "d": 3, "other": 4} def test_dataindices_max(data_indices) -> None: assert max(data_indices.data.input.full) == max(data_indices.data.input.name_to_index.values()) + assert max(data_indices.internal_data.input.full) == max(data_indices.internal_data.input.name_to_index.values()) + assert max(data_indices.internal_data.output.full) == max(data_indices.internal_data.output.name_to_index.values()) assert max(data_indices.data.output.full) == max(data_indices.data.output.name_to_index.values()) assert max(data_indices.model.input.full) == max(data_indices.model.input.name_to_index.values()) + assert max(data_indices.internal_model.input.full) == max(data_indices.internal_model.input.name_to_index.values()) + assert max(data_indices.internal_model.output.full) == max( + data_indices.internal_model.output.name_to_index.values() + ) assert max(data_indices.model.output.full) == max(data_indices.model.output.name_to_index.values()) def test_dataindices_todict(data_indices) -> None: expected_output = { "input": { - "full": torch.Tensor([0, 1, 4]).to(torch.int), - "forcing": torch.Tensor([0]).to(torch.int), + "full": torch.Tensor([0, 1, 4, 5, 6]).to(torch.int), + "forcing": torch.Tensor([0, 4]).to(torch.int), "diagnostic": torch.Tensor([2, 3]).to(torch.int), - "prognostic": torch.Tensor([1, 4]).to(torch.int), + "prognostic": torch.Tensor([1, 5, 6]).to(torch.int), }, "output": { - "full": torch.Tensor([1, 2, 3, 4]).to(torch.int), - "forcing": torch.Tensor([0]).to(torch.int), + "full": torch.Tensor([1, 2, 3, 5, 6]).to(torch.int), + "forcing": torch.Tensor([0, 4]).to(torch.int), "diagnostic": torch.Tensor([2, 3]).to(torch.int), - "prognostic": torch.Tensor([1, 4]).to(torch.int), + "prognostic": torch.Tensor([1, 5, 6]).to(torch.int), }, } @@ -70,19 +120,41 @@ def test_dataindices_todict(data_indices) -> None: assert torch.allclose(value, expected_output[key][subkey]) +def test_internaldataindices_todict(data_indices) -> None: + expected_output = { + "input": { + "full": torch.Tensor([0, 1, 4, 5, 6, 7, 8]).to(torch.int), + "forcing": torch.Tensor([0, 5, 6]).to(torch.int), + "diagnostic": torch.Tensor([2, 3]).to(torch.int), + "prognostic": torch.Tensor([1, 4, 7, 8]).to(torch.int), + }, + "output": { + "full": torch.Tensor([1, 2, 3, 4, 7, 8]).to(torch.int), + "forcing": torch.Tensor([0, 5, 6]).to(torch.int), + "diagnostic": torch.Tensor([2, 3]).to(torch.int), + "prognostic": torch.Tensor([1, 4, 7, 8]).to(torch.int), + }, + } + + for key in ["output", "input"]: + for subkey, value in data_indices.internal_data.todict()[key].items(): + assert subkey in expected_output[key] + assert torch.allclose(value, expected_output[key][subkey]) + + def test_modelindices_todict(data_indices) -> None: expected_output = { "input": { - "full": torch.Tensor([0, 1, 2]).to(torch.int), - "forcing": torch.Tensor([0]).to(torch.int), + "full": torch.Tensor([0, 1, 2, 3, 4]).to(torch.int), + "forcing": torch.Tensor([0, 2]).to(torch.int), "diagnostic": torch.Tensor([]).to(torch.int), - "prognostic": torch.Tensor([1, 2]).to(torch.int), + "prognostic": torch.Tensor([1, 3, 4]).to(torch.int), }, "output": { - "full": torch.Tensor([0, 1, 2, 3]).to(torch.int), + "full": torch.Tensor([0, 1, 2, 3, 4]).to(torch.int), "forcing": torch.Tensor([]).to(torch.int), "diagnostic": torch.Tensor([1, 2]).to(torch.int), - "prognostic": torch.Tensor([0, 3]).to(torch.int), + "prognostic": torch.Tensor([0, 3, 4]).to(torch.int), }, } @@ -90,3 +162,25 @@ def 
test_modelindices_todict(data_indices) -> None: for subkey, value in data_indices.model.todict()[key].items(): assert subkey in expected_output[key] assert torch.allclose(value, expected_output[key][subkey]) + + +def test_internalmodelindices_todict(data_indices) -> None: + expected_output = { + "input": { + "full": torch.Tensor([0, 1, 2, 3, 4, 5, 6]).to(torch.int), + "forcing": torch.Tensor([0, 3, 4]).to(torch.int), + "diagnostic": torch.Tensor([]).to(torch.int), + "prognostic": torch.Tensor([1, 2, 5, 6]).to(torch.int), + }, + "output": { + "full": torch.Tensor([0, 1, 2, 3, 4, 5]).to(torch.int), + "forcing": torch.Tensor([]).to(torch.int), + "diagnostic": torch.Tensor([1, 2]).to(torch.int), + "prognostic": torch.Tensor([0, 3, 4, 5]).to(torch.int), + }, + } + + for key in ["output", "input"]: + for subkey, value in data_indices.internal_model.todict()[key].items(): + assert subkey in expected_output[key] + assert torch.allclose(value, expected_output[key][subkey]) diff --git a/tests/preprocessing/test_preprocessor_imputer.py b/tests/preprocessing/test_preprocessor_imputer.py index ea04b9a..5218efc 100644 --- a/tests/preprocessing/test_preprocessor_imputer.py +++ b/tests/preprocessing/test_preprocessor_imputer.py @@ -26,6 +26,7 @@ def non_default_input_imputer(): "imputer": {"default": "none", "mean": ["y"], "maximum": ["x"], "none": ["z"], "minimum": ["q"]}, "forcing": ["z", "q"], "diagnostic": ["other"], + "remapped": {}, }, }, ) @@ -37,7 +38,7 @@ def non_default_input_imputer(): } name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} data_indices = IndexCollection(config=config, name_to_index=name_to_index) - return InputImputer(config=config.data.imputer, statistics=statistics, data_indices=data_indices) + return InputImputer(config=config.data.imputer, data_indices=data_indices, statistics=statistics) @pytest.fixture() @@ -49,6 +50,7 @@ def default_input_imputer(): "imputer": {"default": "minimum"}, "forcing": ["z", "q"], "diagnostic": ["other"], + "remapped": [], }, }, ) @@ -86,6 +88,7 @@ def default_constant_imputer(): "imputer": {"default": "none", 0: ["x"], 3.0: ["y"], 22.7: ["z"], 10: ["q"]}, "forcing": ["z", "q"], "diagnostic": ["other"], + "remapped": [], }, }, ) @@ -103,6 +106,7 @@ def non_default_constant_imputer(): "imputer": {"default": 22.7}, "forcing": ["z", "q"], "diagnostic": ["other"], + "remapped": [], }, }, ) diff --git a/tests/preprocessing/test_preprocessor_normalizer.py b/tests/preprocessing/test_preprocessor_normalizer.py index 787079d..cc527e7 100644 --- a/tests/preprocessing/test_preprocessor_normalizer.py +++ b/tests/preprocessing/test_preprocessor_normalizer.py @@ -25,6 +25,7 @@ def input_normalizer(): "normalizer": {"default": "mean-std", "min-max": ["x"], "max": ["y"], "none": ["z"], "mean-std": ["q"]}, "forcing": ["z", "q"], "diagnostic": ["other"], + "remapped": {}, }, }, ) @@ -36,7 +37,7 @@ def input_normalizer(): } name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4} data_indices = IndexCollection(config=config, name_to_index=name_to_index) - return InputNormalizer(config=config.data.normalizer, statistics=statistics, data_indices=data_indices) + return InputNormalizer(config=config.data.normalizer, data_indices=data_indices, statistics=statistics) def test_normalizer_not_inplace(input_normalizer) -> None: diff --git a/tests/preprocessing/test_preprocessor_remapper.py b/tests/preprocessing/test_preprocessor_remapper.py new file mode 100644 index 0000000..86bdfde --- /dev/null +++ b/tests/preprocessing/test_preprocessor_remapper.py @@ -0,0 
+1,67 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import pytest +import torch +from omegaconf import DictConfig + +from anemoi.models.data_indices.collection import IndexCollection +from anemoi.models.preprocessing.remapper import Remapper + + +@pytest.fixture() +def input_remapper(): + config = DictConfig( + { + "diagnostics": {"log": {"code": {"level": "DEBUG"}}}, + "data": { + "remapper": { + "cos_sin": { + "d": ["cos_d", "sin_d"], + } + }, + "forcing": ["z", "q"], + "diagnostic": ["other"], + "remapped": { + "d": ["cos_d", "sin_d"], + }, + }, + }, + ) + statistics = {} + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "d": 4, "other": 5} + data_indices = IndexCollection(config=config, name_to_index=name_to_index) + return Remapper(config=config.data.remapper, data_indices=data_indices, statistics=statistics) + + +def test_remap_not_inplace(input_remapper) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) + input_remapper(x, in_place=False) + assert torch.allclose(x, torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])) + + +def test_remap(input_remapper) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) + expected_output = torch.Tensor( + [[1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795]] + ) + assert torch.allclose(input_remapper.transform(x), expected_output) + + +def test_inverse_transform(input_remapper) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795]]) + expected_output = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) + assert torch.allclose(input_remapper.inverse_transform(x), expected_output) + + +def test_remap_inverse_transform(input_remapper) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) + assert torch.allclose( + input_remapper.inverse_transform(input_remapper.transform(x, in_place=False), in_place=False), x + )
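A minimal end-to-end usage sketch of the new Remapper (illustrative only, not part of the diff): the config layout, variable names, and expected values below mirror the fixture in the added `tests/preprocessing/test_preprocessor_remapper.py`.

```python
import torch
from omegaconf import DictConfig

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.preprocessing.remapper import Remapper

# Config as in the test fixture: "d" holds an angle in degrees and is
# remapped to the two variables "cos_d" and "sin_d".
config = DictConfig(
    {
        "data": {
            "remapper": {"cos_sin": {"d": ["cos_d", "sin_d"]}},
            "forcing": ["z", "q"],
            "diagnostic": ["other"],
            "remapped": {"d": ["cos_d", "sin_d"]},
        },
    },
)
name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "d": 4, "other": 5}

data_indices = IndexCollection(config=config, name_to_index=name_to_index)
remapper = Remapper(config=config.data.remapper, data_indices=data_indices, statistics={})

# One grid point, six variables; column 4 ("d") is an angle of 150 degrees.
x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0]])

# transform() drops the "d" column and appends cos_d, sin_d at the end:
# [[1, 2, 3, 4, 5, cos(150 deg), sin(150 deg)]] == [[1, 2, 3, 4, 5, -0.866..., 0.5]]
x_remapped = remapper.transform(x, in_place=False)

# inverse_transform() recovers the angle via atan2 and restores the original layout.
x_restored = remapper.inverse_transform(x_remapped, in_place=False)
assert torch.allclose(x, x_restored)
```

Because the number of columns changes, the remapper can never operate in place; `transform` and `inverse_transform` always allocate a new tensor and only log a warning when `in_place=True` is requested.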