Skip to content

Commit abbef4b

Browse files
authored
fix: only load shards of grid into cpu mem if possible (ecmwf#83)
1 parent 41fcab6 commit abbef4b

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

training/src/anemoi/training/data/dataset.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,15 @@ def __iter__(self) -> torch.Tensor:
279279
end = i + (self.rollout + 1) * self.timeincrement
280280

281281
grid_shard_indices = self.grid_indices.get_shard_indices(self.reader_group_rank)
282-
x = self.data[start : end : self.timeincrement, :, :, :]
283-
x = x[..., grid_shard_indices] # select the grid shard
282+
if isinstance(grid_shard_indices, slice):
283+
# Load only shards into CPU memory
284+
x = self.data[start : end : self.timeincrement, :, :, grid_shard_indices]
285+
else:
286+
# Load full grid in CPU memory, select grid_shard after
287+
# Note that anemoi-datasets currently doesn't support slicing + indexing
288+
# in the same operation.
289+
x = self.data[start : end : self.timeincrement, :, :, :]
290+
x = x[..., grid_shard_indices] # select the grid shard
284291
x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
285292
self.ensemble_dim = 1
286293

0 commit comments

Comments
 (0)