File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed
training/src/anemoi/training/data Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -279,8 +279,15 @@ def __iter__(self) -> torch.Tensor:
279279 end = i + (self .rollout + 1 ) * self .timeincrement
280280
281281 grid_shard_indices = self .grid_indices .get_shard_indices (self .reader_group_rank )
282- x = self .data [start : end : self .timeincrement , :, :, :]
283- x = x [..., grid_shard_indices ] # select the grid shard
282+ if isinstance (grid_shard_indices , slice ):
283+ # Load only shards into CPU memory
284+ x = self .data [start : end : self .timeincrement , :, :, grid_shard_indices ]
285+ else :
286+ # Load full grid in CPU memory, select grid_shard after
287+ # Note that anemoi-datasets currently doesn't support slicing + indexing
288+ # in the same operation.
289+ x = self .data [start : end : self .timeincrement , :, :, :]
290+ x = x [..., grid_shard_indices ] # select the grid shard
284291 x = rearrange (x , "dates variables ensemble gridpoints -> dates ensemble gridpoints variables" )
285292 self .ensemble_dim = 1
286293
You can’t perform that action at this time.
0 commit comments