3030_logger = logging .getLogger (__name__ )
3131
3232
33+ def is_ndarray (obj : typing .Any ) -> bool :
34+ """Check if object is an ndarray (wraps the linter warning)."""
35+ return isinstance (obj , (np .ndarray )) # noqa: TID251
36+
37+
3338@dataclasses .dataclass
3439class IOReaderData :
3540 """
@@ -58,10 +63,10 @@ def create(cls, other: typing.Any) -> "IOReaderData":
5863
5964 other should be such an instance.
6065 """
61- coords = other .coords
62- geoinfos = other .geoinfos
63- data = other .data
64- datetimes = other .datetimes
66+ coords = np . asarray ( other .coords )
67+ geoinfos = np . asarray ( other .geoinfos )
68+ data = np . asarray ( other .data )
69+ datetimes = np . asarray ( other .datetimes )
6570
6671 n_datapoints = len (data )
6772
@@ -130,22 +135,22 @@ class OutputDataset:
130135 item_key : ItemKey
131136
132137 # (datapoints, channels, ens)
133- data : zarr .Array # wrong type => array like
138+ data : zarr .Array | NDArray # wrong type => array like
134139
135140 # (datapoints,)
136- times : zarr .Array
141+ times : zarr .Array | NDArray
137142
138143 # (datapoints, 2)
139- coords : zarr .Array
144+ coords : zarr .Array | NDArray
140145
141146 # (datapoints, geoinfos) geoinfos are stream dependent => 0 for most gridded data
142- geoinfo : zarr .Array
147+ geoinfo : zarr .Array | NDArray
143148
144149 channels : list [str ]
145150 geoinfo_channels : list [str ]
146151
147152 @functools .cached_property
148- def arrays (self ) -> dict [str , zarr .Array ]:
153+ def arrays (self ) -> dict [str , zarr .Array | NDArray ]:
149154 """Iterate over the arrays and their names."""
150155 return {
151156 "data" : self .data ,
@@ -236,7 +241,8 @@ def write_zarr(self, item: OutputItem):
236241 """Write one output item to the zarr store."""
237242 group = self ._get_group (item .key , create = True )
238243 for dataset in item .datasets :
239- self ._write_dataset (group , dataset )
244+ if dataset is not None :
245+ self ._write_dataset (group , dataset )
240246
241247 def get_data (self , sample : int , stream : str , forecast_step : int ) -> OutputItem :
242248 """Get datasets for the output item matching the arguments."""
@@ -285,6 +291,7 @@ def _write_arrays(self, dataset_group: zarr.Group, dataset: OutputDataset):
285291 self ._create_dataset (dataset_group , array_name , array )
286292
287293 def _create_dataset (self , group : zarr .Group , name : str , array : NDArray ):
294+ assert is_ndarray (array ), f"Expected ndarray but got: { type (array )} "
288295 if array .size == 0 : # sometimes for geoinfo
289296 chunks = None
290297 else :
@@ -394,20 +401,10 @@ def extract(self, key: ItemKey) -> OutputItem:
394401 target_data = np .zeros ((0 , len (self .target_channels [stream_idx ])), dtype = np .float32 )
395402 preds_data = np .zeros ((0 , len (self .target_channels [stream_idx ])), dtype = np .float32 )
396403 else :
397- target_data = (
398- self .targets [offset_key .forecast_step ][stream_idx ][0 ][datapoints ]
399- .cpu ()
400- .detach ()
401- .numpy ()
402- )
403- preds_data = (
404- self .predictions [offset_key .forecast_step ][stream_idx ][0 ]
405- .transpose (1 , 0 )
406- .transpose (1 , 2 )[datapoints ]
407- .cpu ()
408- .detach ()
409- .numpy ()
410- )
404+ target_data = self .targets [offset_key .forecast_step ][stream_idx ][0 ][datapoints ]
405+ preds_data = self .predictions [offset_key .forecast_step ][stream_idx ][0 ].transpose (
406+ 1 , 2 , 0
407+ )[datapoints ]
411408
412409 data_coords = self ._extract_coordinates (stream_idx , offset_key , datapoints )
413410
@@ -423,6 +420,8 @@ def extract(self, key: ItemKey) -> OutputItem:
423420 else :
424421 source_dataset = None
425422
423+ assert is_ndarray (target_data ), f"Expected ndarray but got: { type (target_data )} "
424+ assert is_ndarray (preds_data ), f"Expected ndarray but got: { type (preds_data )} "
426425 return OutputItem (
427426 key = key ,
428427 source = source_dataset ,
@@ -501,10 +500,10 @@ def _extract_sources(self, sample, stream_idx, key):
501500 source_dataset = OutputDataset (
502501 "source" ,
503502 key ,
504- source .data ,
505- source .datetimes ,
506- source .coords ,
507- source .geoinfos ,
503+ np . asarray ( source .data ) ,
504+ np . asarray ( source .datetimes ) ,
505+ np . asarray ( source .coords ) ,
506+ np . asarray ( source .geoinfos ) ,
508507 channels ,
509508 geoinfo_channels ,
510509 )
0 commit comments