@@ -240,7 +240,8 @@ def forecast_steps(self) -> list[int]:
240240class OutputBatchData :
241241 """Provide convenient access to adapt existing output data structures."""
242242
243- # sample, stream, tensor(datapoint, channel) => datapoints is accross all datasets per stream
243+ # sample, stream, tensor(datapoint, channel+coords)
244+ # => datapoints is accross all datasets per stream
244245 sources : list [list ]
245246
246247 # fstep, stream, redundant dim (size 1), tensor(sample x datapoint, channel)
@@ -258,7 +259,8 @@ class OutputBatchData:
258259 # fstep, stream, redundant dim (size 1)
259260 targets_lens : list [list [list [int ]]]
260261
261- stream_names : list [str ]
262+ # stream name: index into data (only streams in analysis_streams_output)
263+ streams : dict [str , int ]
262264
263265 # stream, channel name
264266 channels : list [list [str ]]
@@ -279,22 +281,23 @@ def forecast_steps(self):
279281
280282 def items (self ) -> typing .Generator [OutputItem , None , None ]:
281283 """Iterate over possible output items"""
282- filtered_streams = (stream for stream in self .stream_names if stream != "" )
283284 # TODO: filter for empty items?
284- for s , fo_s , fi_s in itertools .product (self .samples , self .forecast_steps , filtered_streams ):
285+ for s , fo_s , fi_s in itertools .product (
286+ self .samples , self .forecast_steps , self .streams .keys ()
287+ ):
285288 yield self .extract (ItemKey (int (s ), int (fo_s ), fi_s ))
286289
287290 def extract (self , key : ItemKey ) -> OutputItem :
288291 """Extract datasets from lists for one output item."""
289292 # adjust shifted values in ItemMeta
290293 sample = key .sample - self .sample_start
291294 forecast_step = key .forecast_step - self .forecast_offset
292- stream_idx = self .stream_names . index ( key .stream ) # TODO: assure this is correct
295+ stream_idx = self .streams [ key .stream ]
293296 lens = self .targets_lens [forecast_step ][stream_idx ]
294297 start = sum (lens [:sample ])
295298 n_samples = lens [sample ]
296299
297- _logger .info ("extracting subset" )
300+ _logger .info (f "extracting subset: { key } " )
298301 _logger .info (
299302 f"sample: start:{ self .sample_start } rel_idx:{ sample } range:{ start } -{ start + n_samples } "
300303 )
@@ -331,18 +334,40 @@ def extract(self, key: ItemKey) -> OutputItem:
331334 channels = self .channels [stream_idx ]
332335 geoinfo_channels = self .geoinfo_channels [stream_idx ]
333336
337+ assert len (channels ) == target_data .shape [1 ], (
338+ "Number of channel names does not align with data"
339+ )
340+ assert len (channels ) == preds_data .shape [1 ], (
341+ "Number of channel names does not align with data"
342+ )
343+
334344 if key .with_source :
335345 source_data = self .sources [sample ][stream_idx ].cpu ().detach ().numpy ()
346+
347+ # split data into coords, geoinfo, channels
348+ _source_coords = source_data [:, : - len (channels )]
349+ source_coords = _source_coords [:, :2 ]
350+ source_times = _source_coords [:, 2 ]
351+ source_geoinfo = _source_coords [:, 2 : - len (channels )]
352+
353+ # TODO asserts that times, coords, geoinfos should match?
354+
336355 source_dataset = OutputDataset (
337356 "source" ,
338357 key ,
339- source_data ,
340- times ,
341- coords ,
342- geoinfo ,
358+ source_data [:, - len ( channels ) :] ,
359+ source_times ,
360+ source_coords ,
361+ source_geoinfo ,
343362 channels ,
344363 geoinfo_channels ,
345364 )
365+
366+ _logger .info (f"source shape: { source_dataset .data .shape } " )
367+ assert len (channels ) == source_dataset .data .shape [1 ], (
368+ "Number of channel names does not align with data"
369+ )
370+ assert len (geoinfo_channels ) == source_dataset .geoinfo .shape [1 ]
346371 else :
347372 source_dataset = None
348373
0 commit comments