|
| 1 | +# (C) Copyright 2025 Anemoi contributors. |
| 2 | +# |
| 3 | +# This software is licensed under the terms of the Apache Licence Version 2.0 |
| 4 | +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. |
| 5 | +# |
| 6 | +# In applying this licence, ECMWF does not waive the privileges and immunities |
| 7 | +# granted to it by virtue of its status as an intergovernmental organisation |
| 8 | +# nor does it submit to any jurisdiction. |
| 9 | + |
| 10 | +import logging |
| 11 | + |
| 12 | +import numpy as np |
| 13 | +import torch |
| 14 | + |
| 15 | +from anemoi.models.distributed.graph import gather_channels |
| 16 | +from anemoi.models.distributed.graph import shard_channels |
| 17 | +from anemoi.models.distributed.shapes import get_or_apply_shard_shapes |
| 18 | + |
| 19 | +LOGGER = logging.getLogger(__name__) |
| 20 | + |
| 21 | + |
| 22 | +class BaseTruncation: |
| 23 | + """Apply resolution truncation/upsampling via sparse projection matrices. |
| 24 | +
|
| 25 | + This utility holds two (optional) sparse COO matrices: |
| 26 | +
|
| 27 | + - ``A_down``: projects from a high-resolution representation to a coarse one |
| 28 | + (e.g., spectral/graph truncation). |
| 29 | + - ``A_up``: projects from the coarse representation back to high resolution |
| 30 | + (e.g., zero-padding or learned up-projection). |
| 31 | +
|
| 32 | + Both matrices are expected in SciPy CSR/COO-like format at construction time |
| 33 | + and are converted to PyTorch sparse tensors. During ``__call__`` the |
| 34 | + matrices are moved to the input device (first use) and applied per sample in |
| 35 | + the batch. When inputs are grid-sharded across ranks, tensors are reshaped |
| 36 | + to channel-sharding to apply the projection on the full sequence and then |
| 37 | + restored to their original sharding scheme. |
| 38 | +
|
| 39 | + Notes |
| 40 | + ----- |
| 41 | + - Sparse tensors are **not** registered as buffers because DDP does not |
| 42 | + reliably broadcast sparse tensors; instead the matrices are lazily moved |
| 43 | + to the correct device on first use. |
| 44 | + - Matrix–tensor multiplication is performed as ``A @ X`` (left |
| 45 | + multiplication), where ``A`` is sparse (``[n_out, n_in]``) and ``X`` is |
| 46 | + dense (``[n_in, d]``), producing ``[n_out, d]``. |
| 47 | + """ |
| 48 | + |
| 49 | + def __init__(self, truncation_data: dict) -> None: |
| 50 | + """Build the truncation matrices. |
| 51 | +
|
| 52 | + Parameters |
| 53 | + ---------- |
| 54 | + truncation_data : dict |
| 55 | + Dictionary possibly containing keys ``"down"`` and/or ``"up"`` with |
| 56 | + SciPy sparse matrices. ``"down"`` defines the high→coarse projection |
| 57 | + (stored as ``A_down``); ``"up"`` defines the coarse→high projection |
| 58 | + (stored as ``A_up``). |
| 59 | + """ |
| 60 | + self.A_down, self.A_up = None, None |
| 61 | + if "down" in truncation_data: |
| 62 | + self.A_down = self._make_truncation_matrix(truncation_data["down"]) |
| 63 | + LOGGER.info("Truncation: A_down %s", self.A_down.shape) |
| 64 | + if "up" in truncation_data: |
| 65 | + self.A_up = self._make_truncation_matrix(truncation_data["up"]) |
| 66 | + LOGGER.info("Truncation: A_up %s", self.A_up.shape) |
| 67 | + |
| 68 | + def _make_truncation_matrix(self, A, data_type=torch.float32): |
| 69 | + """Convert a SciPy sparse matrix to a coalesced PyTorch COO tensor. |
| 70 | +
|
| 71 | + Parameters |
| 72 | + ---------- |
| 73 | + A : scipy.sparse.spmatrix |
| 74 | + Input sparse matrix with shape ``(n_out, n_in)``. |
| 75 | + data_type : torch.dtype, optional |
| 76 | + Target dtype for the tensor values, by default ``torch.float32``. |
| 77 | +
|
| 78 | + Returns |
| 79 | + ------- |
| 80 | + torch.Tensor |
| 81 | + A coalesced sparse COO tensor with the same shape as ``A``. |
| 82 | + """ |
| 83 | + A_ = torch.sparse_coo_tensor( |
| 84 | + torch.tensor(np.vstack(A.nonzero()), dtype=torch.long), |
| 85 | + torch.tensor(A.data, dtype=data_type), |
| 86 | + size=A.shape, |
| 87 | + ).coalesce() |
| 88 | + return A_ |
| 89 | + |
| 90 | + def _multiply_sparse(self, x, A): |
| 91 | + """Left-multiply a dense matrix by a sparse projection. |
| 92 | +
|
| 93 | + Parameters |
| 94 | + ---------- |
| 95 | + x : torch.Tensor |
| 96 | + Dense 2-D tensor with shape ``(n_in, d)``. |
| 97 | + A : torch.Tensor |
| 98 | + Sparse COO tensor with shape ``(n_out, n_in)``. |
| 99 | +
|
| 100 | + Returns |
| 101 | + ------- |
| 102 | + torch.Tensor |
| 103 | + Dense 2-D tensor with shape ``(n_out, d)`` equal to ``A @ x``. |
| 104 | + """ |
| 105 | + return torch.sparse.mm(A, x) |
| 106 | + |
| 107 | + def _truncate_fields(self, x, A, batch_size=None, auto_cast=False): |
| 108 | + """Apply a sparse projection to each item in a batch. |
| 109 | +
|
| 110 | + Parameters |
| 111 | + ---------- |
| 112 | + x : torch.Tensor |
| 113 | + Dense 3-D tensor with shape ``(B, n_in, d)``. For each batch item |
| 114 | + ``i``, ``x[i]`` is multiplied as ``A @ x[i]``. |
| 115 | + A : torch.Tensor |
| 116 | + Sparse COO tensor with shape ``(n_out, n_in)``. |
| 117 | + batch_size : int, optional |
| 118 | + Number of batch elements to process. If ``None`` (default), uses |
| 119 | + ``x.shape[0]``. |
| 120 | + auto_cast : bool, optional |
| 121 | + If ``True``, enables CUDA autocast for the multiplication loop. |
| 122 | +
|
| 123 | + Returns |
| 124 | + ------- |
| 125 | + torch.Tensor |
| 126 | + Dense 3-D tensor with shape ``(B, n_out, d)`` containing the |
| 127 | + projected batch. |
| 128 | + """ |
| 129 | + if not batch_size: |
| 130 | + batch_size = x.shape[0] |
| 131 | + out = [] |
| 132 | + with torch.amp.autocast(device_type="cuda", enabled=auto_cast): |
| 133 | + for i in range(batch_size): |
| 134 | + out.append(self._multiply_sparse(x[i, ...], A)) |
| 135 | + return torch.stack(out) |
| 136 | + |
| 137 | + def __call__(self, x, grid_shard_shapes=None, model_comm_group=None): |
| 138 | + """Apply down/up truncation to a (possibly sharded) batch. |
| 139 | +
|
| 140 | + This function optionally: |
| 141 | + 1) Reshapes grid-sharded inputs to channel-sharded layout to expose the |
| 142 | + full sequence to the projection matrices. |
| 143 | + 2) Applies ``A_down`` (high→coarse) and/or ``A_up`` (coarse→high) per |
| 144 | + batch element when provided. |
| 145 | + 3) Restores the original sharding layout. |
| 146 | +
|
| 147 | + Parameters |
| 148 | + ---------- |
| 149 | + x : torch.Tensor |
| 150 | + Input dense tensor of shape ``(B, n_in, d)`` if unsharded. When |
| 151 | + grid-sharded, the leading dimensions depend on the sharding layout; |
| 152 | + this method will handle reshaping internally. |
| 153 | + grid_shard_shapes : Any, optional |
| 154 | + Distributed shape metadata used to convert between grid and |
| 155 | + channel sharding. If ``None``, no resharding is performed. |
| 156 | + model_comm_group : Any, optional |
| 157 | + Communication group handle used by distributed helpers. |
| 158 | +
|
| 159 | + Returns |
| 160 | + ------- |
| 161 | + torch.Tensor |
| 162 | + Output tensor with the same global shape semantics as ``x``. If |
| 163 | + truncation matrices are present, the ``n_in`` dimension is replaced |
| 164 | + by the corresponding ``n_out`` after projection. |
| 165 | + """ |
| 166 | + if self.A_down is not None or self.A_up is not None: |
| 167 | + if grid_shard_shapes is not None: |
| 168 | + shard_shapes = get_or_apply_shard_shapes(x, 0, grid_shard_shapes, model_comm_group) |
| 169 | + # grid-sharded input: reshard to channel-shards to apply truncation |
| 170 | + x = shard_channels(x, shard_shapes, model_comm_group) # we get the full sequence here |
| 171 | + |
| 172 | + # these can't be registered as buffers because ddp does not like to broadcast sparse tensors |
| 173 | + # hence we check that they are on the correct device ; copy should only happen in the first forward run |
| 174 | + if self.A_down is not None: |
| 175 | + self.A_down = self.A_down.to(x.device) |
| 176 | + x = self._truncate_fields(x, self.A_down) # to coarse resolution |
| 177 | + if self.A_up is not None: |
| 178 | + self.A_up = self.A_up.to(x.device) |
| 179 | + x = self._truncate_fields(x, self.A_up) # back to high resolution |
| 180 | + |
| 181 | + if grid_shard_shapes is not None: |
| 182 | + # back to grid-sharding as before |
| 183 | + x = gather_channels(x, shard_shapes, model_comm_group) |
| 184 | + |
| 185 | + return x |
0 commit comments