Commit eb3edc5

icedoom888, Alberto Pennino, pre-commit-ci[bot], ssmmnn11 authored
feat: introducing AnemoiBaseModel (ecmwf#440)
[![Code Quality and Testing](https://github.com/ecmwf/anemoi-core/actions/workflows/python-pull-request.yml/badge.svg?branch=feat%2Fanemoibasemodel)](https://github.com/ecmwf/anemoi-core/actions/workflows/python-pull-request.yml)

## Description

Following the discussion in ecmwf#270 and PR ecmwf#399, this change introduces a new base class for anemoi models to make it easier to integrate new tasks and models.

## What problem does this change solve?

Makes inheritance easier and clearer for existing and upcoming models. The new internal `_build_networks` function makes the network elements more explicit.

## What issue or task does this change relate to?

ecmwf#270

## Additional notes

***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/***

By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)

📚 Documentation preview 📚: https://anemoi-training--440.org.readthedocs.build/en/440/

📚 Documentation preview 📚: https://anemoi-graphs--440.org.readthedocs.build/en/440/

📚 Documentation preview 📚: https://anemoi-models--440.org.readthedocs.build/en/440/

Co-authored-by: Alberto Pennino <apennino@santis-ln002.cscs.ch>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Lang <simon.lang@ecmwf.int>
1 parent 0e8ccad commit eb3edc5

14 files changed: +690 -371 lines changed

models/README.md

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ This project is **BETA** and will be **Experimental** for the foreseeable future
 Interfaces and functionality are likely to change, and the project itself may be scrapped.
 **DO NOT** use this software in any project/software that is operational.

-Miscellanous tools for training data-driven weather forecasts.
+Miscellanous tools for training data-driven weather forecasting models.

 ## Documentation

models/docs/introduction/overview.rst

Lines changed: 2 additions & 2 deletions
@@ -87,8 +87,8 @@ to process the input data.
 The layers are designed as extensible classes to allow for easy
 experimentation and switching out of components.

-Mappers
-=======
+Graph Mappers
+=============

 The layers implement `Mappers`, which maps data between the input grid
 and the internal hidden grid. The `Mappers` are used as encoder and

models/src/anemoi/models/distributed/shapes.py

Lines changed: 7 additions & 0 deletions
@@ -38,3 +38,10 @@ def apply_shard_shapes(tensor: Tensor, dim: int, shard_shapes_dim: list) -> list
         shard_shapes[i][dim] = shard_shape

     return shard_shapes
+
+
+def get_or_apply_shard_shapes(x, dim=0, shard_shapes_dim: int = None, model_comm_group: Optional[ProcessGroup] = None):
+    if shard_shapes_dim is None:
+        return get_shard_shapes(x, dim, model_comm_group)
+    else:
+        return apply_shard_shapes(x, dim, shard_shapes_dim)
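
The new helper is a get-or-apply dispatch: when no per-shard sizes are supplied it derives them, otherwise it applies the supplied sizes. The following is a minimal pure-Python sketch of that pattern, not the library code. `get_shapes`, `apply_shapes`, `get_or_apply` and the `world_size` parameter are hypothetical stand-ins, and the even-split rule in `get_shapes` is an assumption about how sizes might be derived.

```python
from typing import Optional


def apply_shapes(shape: list, dim: int, sizes: list) -> list:
    """Copy `shape` once per shard and override `dim` with that shard's size."""
    shard_shapes = [list(shape) for _ in sizes]
    for i, size in enumerate(sizes):
        shard_shapes[i][dim] = size
    return shard_shapes


def get_shapes(shape: list, dim: int, world_size: int) -> list:
    """Hypothetical stand-in: split `shape[dim]` as evenly as possible (assumption)."""
    base, rest = divmod(shape[dim], world_size)
    sizes = [base + (1 if rank < rest else 0) for rank in range(world_size)]
    return apply_shapes(shape, dim, sizes)


def get_or_apply(shape: list, dim: int = 0, sizes: Optional[list] = None, world_size: int = 1) -> list:
    """Derive shard shapes when `sizes` is None, otherwise apply the given sizes."""
    if sizes is None:
        return get_shapes(shape, dim, world_size)
    return apply_shapes(shape, dim, sizes)


derived = get_or_apply([10, 4], dim=0, world_size=3)
applied = get_or_apply([10, 4], dim=0, sizes=[7, 3])
```

Callers that already know the sharding pass it through unchanged; callers that do not get a derived layout from the same entry point.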

models/src/anemoi/models/layers/bounding.py

Lines changed: 54 additions & 0 deletions
@@ -11,9 +11,12 @@

 from abc import ABC
 from abc import abstractmethod
+from typing import Any
+from typing import Iterable
 from typing import Optional

 import torch
+from hydra.utils import instantiate
 from torch import nn

 from anemoi.models.data_indices.tensor import InputTensorIndex
@@ -301,3 +304,54 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Calculate the fraction of the total variable
         x[..., self.data_index] *= x[..., self.total_variable]
         return x
+
+
+def build_boundings(
+    model_config: Any,
+    data_indices: Any,
+    statistics: dict | None,
+) -> nn.ModuleList:
+    """Build the list of model-output bounding modules from configuration.
+
+    This is a thin factory over Hydra's ``instantiate`` that reads the iterable
+    ``model_config.model.bounding`` and instantiates each entry while injecting
+    the common keyword arguments required by bounding modules:
+    ``name_to_index``, ``statistics``, and ``name_to_index_stats``. The result
+    is returned as an ``nn.ModuleList`` preserving the order of the config.
+
+    Parameters
+    ----------
+    model_config : Any
+        Object with a ``model`` attribute containing an iterable ``bounding``
+        (e.g. a list of Hydra configs). If absent or empty, an empty
+        ``nn.ModuleList`` is returned.
+    data_indices : Any
+        Object providing the mappings:
+        ``data_indices.model.output.name_to_index`` and
+        ``data_indices.data.input.name_to_index``. These are forwarded to each
+        instantiated bounding module as ``name_to_index`` and
+        ``name_to_index_stats`` respectively.
+    statistics : dict | None
+        Optional dataset/model statistics passed to each bounding module. Use
+        ``None`` if not required by the configured classes.
+
+    Returns
+    -------
+    torch.nn.ModuleList
+        The instantiated bounding modules, in the same order as specified in
+        ``model_config.model.bounding``. May be empty.
+    """
+
+    bounding_cfgs: Iterable[Any] = getattr(getattr(model_config, "model", object()), "bounding", []) or []
+
+    return nn.ModuleList(
+        [
+            instantiate(
+                cfg,
+                name_to_index=data_indices.model.output.name_to_index,
+                statistics=statistics,
+                name_to_index_stats=data_indices.data.input.name_to_index,
+            )
+            for cfg in bounding_cfgs
+        ]
+    )
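
The docstring says an absent or empty `bounding` entry yields an empty list. The sketch below isolates the nested `getattr(...) or []` lookup to show how it could behave on plain attribute objects. It uses `types.SimpleNamespace` in place of a real Hydra/OmegaConf config, which is an assumption, and `factory` is a hypothetical stand-in for `instantiate`; real config objects may treat missing keys differently.

```python
from types import SimpleNamespace
from typing import Any, Callable, Iterable


def bounding_configs(model_config: Any) -> Iterable[Any]:
    """Same lookup expression as in build_boundings, isolated for illustration."""
    return getattr(getattr(model_config, "model", object()), "bounding", []) or []


def build(model_config: Any, factory: Callable[..., Any]) -> list:
    """Apply `factory` (hypothetical stand-in for instantiate) to each entry, in order."""
    return [factory(cfg, statistics=None) for cfg in bounding_configs(model_config)]


no_model = SimpleNamespace()                                      # no `model` attribute
no_bounding = SimpleNamespace(model=SimpleNamespace())            # no `bounding` attribute
none_bounding = SimpleNamespace(model=SimpleNamespace(bounding=None))
two_entries = SimpleNamespace(model=SimpleNamespace(bounding=["a", "b"]))

built = build(two_entries, lambda cfg, **kwargs: (cfg, kwargs))
```

The trailing `or []` is what turns an explicit `None` into an empty iterable, so the list comprehension never iterates over `None`.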

models/src/anemoi/models/layers/mapper.py

Lines changed: 54 additions & 6 deletions
@@ -400,7 +400,7 @@ def run_processor_chunk_edge_sharding(

         return self.post_process(x_dst_out, shapes[1], model_comm_group, keep_x_dst_sharded=True)

-    def forward_with_edge_sharding(
+    def mapper_forward_with_edge_sharding(
         self,
         x: PairTensor,
         batch_size: int,
@@ -453,7 +453,7 @@ def forward_with_edge_sharding(

         return out_dst

-    def forward_with_heads_sharding(
+    def mapper_forward_with_heads_sharding(
         self,
         x: PairTensor,
         batch_size: int,
@@ -513,9 +513,9 @@ def forward(
         }

         if self.shard_strategy == "edges":
-            return self.forward_with_edge_sharding(**kwargs_forward)
+            return self.mapper_forward_with_edge_sharding(**kwargs_forward)
         else:  # self.shard_strategy == "heads"
-            return self.forward_with_heads_sharding(**kwargs_forward)
+            return checkpoint(self.mapper_forward_with_heads_sharding, **kwargs_forward, use_reentrant=False)


 class GraphTransformerForwardMapper(ForwardMapperPreProcessMixin, GraphTransformerBaseMapper):
@@ -818,7 +818,7 @@ def prepare_edges(
         edge_attr = self.emb_edges(edge_attr)
         return edge_attr, edge_index

-    def forward(
+    def mapper_forward(
         self,
         x: PairTensor,
         batch_size: int,
@@ -852,6 +852,30 @@ def forward(

         return x_src, x_dst

+    def forward(
+        self,
+        x: PairTensor,
+        batch_size: int,
+        shard_shapes: tuple[tuple[int], tuple[int]],
+        model_comm_group: Optional[ProcessGroup] = None,
+        x_src_is_sharded: bool = False,
+        x_dst_is_sharded: bool = False,
+        keep_x_dst_sharded: bool = False,
+        **kwargs,
+    ) -> PairTensor:
+        return checkpoint(
+            self.mapper_forward,
+            x=x,
+            batch_size=batch_size,
+            shard_shapes=shard_shapes,
+            model_comm_group=model_comm_group,
+            x_src_is_sharded=x_src_is_sharded,
+            x_dst_is_sharded=x_dst_is_sharded,
+            keep_x_dst_sharded=keep_x_dst_sharded,
+            **kwargs,
+            use_reentrant=False,
+        )
+

 class GNNForwardMapper(ForwardMapperPreProcessMixin, GNNBaseMapper):
     """Graph Neural Network Mapper data -> hidden."""
@@ -1155,7 +1179,7 @@ def __init__(

         self.emb_nodes_dst = nn.Linear(self.in_channels_dst, self.hidden_dim)

-    def forward(
+    def mapper_forward(
         self,
         x: PairTensor,
         batch_size: int,
@@ -1181,6 +1205,30 @@ def forward(

         return x_dst

+    def forward(
+        self,
+        x: PairTensor,
+        batch_size: int,
+        shard_shapes: tuple[tuple[int], tuple[int]],
+        model_comm_group: Optional[ProcessGroup] = None,
+        x_src_is_sharded: bool = False,
+        x_dst_is_sharded: bool = False,
+        keep_x_dst_sharded: bool = False,
+        **kwargs,
+    ) -> PairTensor:
+        return checkpoint(
+            self.mapper_forward,
+            x=x,
+            batch_size=batch_size,
+            shard_shapes=shard_shapes,
+            model_comm_group=model_comm_group,
+            x_src_is_sharded=x_src_is_sharded,
+            x_dst_is_sharded=x_dst_is_sharded,
+            keep_x_dst_sharded=keep_x_dst_sharded,
+            **kwargs,
+            use_reentrant=False,
+        )
+

 class TransformerForwardMapper(ForwardMapperPreProcessMixin, TransformerBaseMapper):
     """Transformer Mapper from data -> hidden."""
Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
+# (C) Copyright 2025 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import logging
+
+import numpy as np
+import torch
+
+from anemoi.models.distributed.graph import gather_channels
+from anemoi.models.distributed.graph import shard_channels
+from anemoi.models.distributed.shapes import get_or_apply_shard_shapes
+
+LOGGER = logging.getLogger(__name__)
+
+
+class BaseTruncation:
+    """Apply resolution truncation/upsampling via sparse projection matrices.
+
+    This utility holds two (optional) sparse COO matrices:
+
+    - ``A_down``: projects from a high-resolution representation to a coarse one
+      (e.g., spectral/graph truncation).
+    - ``A_up``: projects from the coarse representation back to high resolution
+      (e.g., zero-padding or learned up-projection).
+
+    Both matrices are expected in SciPy CSR/COO-like format at construction time
+    and are converted to PyTorch sparse tensors. During ``__call__`` the
+    matrices are moved to the input device (first use) and applied per sample in
+    the batch. When inputs are grid-sharded across ranks, tensors are reshaped
+    to channel-sharding to apply the projection on the full sequence and then
+    restored to their original sharding scheme.
+
+    Notes
+    -----
+    - Sparse tensors are **not** registered as buffers because DDP does not
+      reliably broadcast sparse tensors; instead the matrices are lazily moved
+      to the correct device on first use.
+    - Matrix–tensor multiplication is performed as ``A @ X`` (left
+      multiplication), where ``A`` is sparse (``[n_out, n_in]``) and ``X`` is
+      dense (``[n_in, d]``), producing ``[n_out, d]``.
+    """
+
+    def __init__(self, truncation_data: dict) -> None:
+        """Build the truncation matrices.
+
+        Parameters
+        ----------
+        truncation_data : dict
+            Dictionary possibly containing keys ``"down"`` and/or ``"up"`` with
+            SciPy sparse matrices. ``"down"`` defines the high→coarse projection
+            (stored as ``A_down``); ``"up"`` defines the coarse→high projection
+            (stored as ``A_up``).
+        """
+        self.A_down, self.A_up = None, None
+        if "down" in truncation_data:
+            self.A_down = self._make_truncation_matrix(truncation_data["down"])
+            LOGGER.info("Truncation: A_down %s", self.A_down.shape)
+        if "up" in truncation_data:
+            self.A_up = self._make_truncation_matrix(truncation_data["up"])
+            LOGGER.info("Truncation: A_up %s", self.A_up.shape)
+
+    def _make_truncation_matrix(self, A, data_type=torch.float32):
+        """Convert a SciPy sparse matrix to a coalesced PyTorch COO tensor.
+
+        Parameters
+        ----------
+        A : scipy.sparse.spmatrix
+            Input sparse matrix with shape ``(n_out, n_in)``.
+        data_type : torch.dtype, optional
+            Target dtype for the tensor values, by default ``torch.float32``.
+
+        Returns
+        -------
+        torch.Tensor
+            A coalesced sparse COO tensor with the same shape as ``A``.
+        """
+        A_ = torch.sparse_coo_tensor(
+            torch.tensor(np.vstack(A.nonzero()), dtype=torch.long),
+            torch.tensor(A.data, dtype=data_type),
+            size=A.shape,
+        ).coalesce()
+        return A_
+
+    def _multiply_sparse(self, x, A):
+        """Left-multiply a dense matrix by a sparse projection.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Dense 2-D tensor with shape ``(n_in, d)``.
+        A : torch.Tensor
+            Sparse COO tensor with shape ``(n_out, n_in)``.
+
+        Returns
+        -------
+        torch.Tensor
+            Dense 2-D tensor with shape ``(n_out, d)`` equal to ``A @ x``.
+        """
+        return torch.sparse.mm(A, x)
+
+    def _truncate_fields(self, x, A, batch_size=None, auto_cast=False):
+        """Apply a sparse projection to each item in a batch.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Dense 3-D tensor with shape ``(B, n_in, d)``. For each batch item
+            ``i``, ``x[i]`` is multiplied as ``A @ x[i]``.
+        A : torch.Tensor
+            Sparse COO tensor with shape ``(n_out, n_in)``.
+        batch_size : int, optional
+            Number of batch elements to process. If ``None`` (default), uses
+            ``x.shape[0]``.
+        auto_cast : bool, optional
+            If ``True``, enables CUDA autocast for the multiplication loop.
+
+        Returns
+        -------
+        torch.Tensor
+            Dense 3-D tensor with shape ``(B, n_out, d)`` containing the
+            projected batch.
+        """
+        if not batch_size:
+            batch_size = x.shape[0]
+        out = []
+        with torch.amp.autocast(device_type="cuda", enabled=auto_cast):
+            for i in range(batch_size):
+                out.append(self._multiply_sparse(x[i, ...], A))
+        return torch.stack(out)
+
+    def __call__(self, x, grid_shard_shapes=None, model_comm_group=None):
+        """Apply down/up truncation to a (possibly sharded) batch.
+
+        This function optionally:
+        1) Reshapes grid-sharded inputs to channel-sharded layout to expose the
+           full sequence to the projection matrices.
+        2) Applies ``A_down`` (high→coarse) and/or ``A_up`` (coarse→high) per
+           batch element when provided.
+        3) Restores the original sharding layout.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input dense tensor of shape ``(B, n_in, d)`` if unsharded. When
+            grid-sharded, the leading dimensions depend on the sharding layout;
+            this method will handle reshaping internally.
+        grid_shard_shapes : Any, optional
+            Distributed shape metadata used to convert between grid and
+            channel sharding. If ``None``, no resharding is performed.
+        model_comm_group : Any, optional
+            Communication group handle used by distributed helpers.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor with the same global shape semantics as ``x``. If
+            truncation matrices are present, the ``n_in`` dimension is replaced
+            by the corresponding ``n_out`` after projection.
+        """
+        if self.A_down is not None or self.A_up is not None:
+            if grid_shard_shapes is not None:
+                shard_shapes = get_or_apply_shard_shapes(x, 0, grid_shard_shapes, model_comm_group)
+                # grid-sharded input: reshard to channel-shards to apply truncation
+                x = shard_channels(x, shard_shapes, model_comm_group)  # we get the full sequence here
+
+            # these can't be registered as buffers because ddp does not like to broadcast sparse tensors
+            # hence we check that they are on the correct device ; copy should only happen in the first forward run
+            if self.A_down is not None:
+                self.A_down = self.A_down.to(x.device)
+                x = self._truncate_fields(x, self.A_down)  # to coarse resolution
+            if self.A_up is not None:
+                self.A_up = self.A_up.to(x.device)
+                x = self._truncate_fields(x, self.A_up)  # back to high resolution
+
+            if grid_shard_shapes is not None:
+                # back to grid-sharding as before
+                x = gather_channels(x, shard_shapes, model_comm_group)
+
+        return x
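
The per-sample `A @ X` projection described in the docstrings above can be illustrated without PyTorch. This is a minimal pure-Python sketch, not the class's implementation: the sparse matrix is a list of `(row, col, value)` triples, and the `A_down`/`A_up` values (pairwise averaging, then copying back) are made-up example matrices.

```python
def coo_matmul(entries, n_out, x):
    """Compute A @ x, with A of shape (n_out, n_in) given as (row, col, value) triples."""
    d = len(x[0])
    out = [[0.0] * d for _ in range(n_out)]
    for row, col, value in entries:
        for k in range(d):
            out[row][k] += value * x[col][k]
    return out


def truncate_fields(batch, entries, n_out):
    """Apply the same sparse projection to every sample in the batch."""
    return [coo_matmul(entries, n_out, sample) for sample in batch]


# Hypothetical matrices: 4 fine points -> 2 coarse points (pairwise mean),
# then 2 coarse points -> 4 fine points (each coarse value copied twice).
A_down = [(0, 0, 0.5), (0, 1, 0.5), (1, 2, 0.5), (1, 3, 0.5)]
A_up = [(0, 0, 1.0), (1, 0, 1.0), (2, 1, 1.0), (3, 1, 1.0)]

batch = [[[1.0], [3.0], [5.0], [7.0]]]                # shape (B=1, n_in=4, d=1)
coarse = truncate_fields(batch, A_down, n_out=2)      # shape (1, 2, 1)
smoothed = truncate_fields(coarse, A_up, n_out=4)     # shape (1, 4, 1)
```

Applying the down projection and then the up projection returns a tensor at the original resolution with the fine-scale detail removed, which is the effect the `__call__` method produces when both matrices are present.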
