@@ -285,6 +285,15 @@ def forecast_steps(self) -> list[int]:
         return list(example_stream.group_keys())
 
 
+@dataclasses.dataclass
+class DataCoordinates:
+    times: typing.Any
+    coords: typing.Any
+    geoinfo: typing.Any
+    channels: typing.Any
+    geoinfo_channels: typing.Any
+
+
 @dataclasses.dataclass
 class OutputBatchData:
     """Provide convenient access to adapt existing output data structures."""
@@ -312,7 +321,8 @@ class OutputBatchData:
     streams: dict[str, int]
 
     # stream, channel name
-    channels: list[list[str]]
+    target_channels: list[list[str]]
+    source_channels: list[list[str]]
     geoinfo_channels: list[list[str]]
 
     sample_start: int
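
Editor's note: splitting the single per-stream channels list into target_channels and source_channels is what lets _extract_sources() further down re-enable the source shape assert that the old code had to comment out ("currently fails since no separate channels for source/target implemented"). A hypothetical per-stream layout, purely for illustration; stream and channel names are not from this commit:

    # `batch` is a hypothetical OutputBatchData instance with a single stream;
    # each list has one entry per stream, indexed via self.streams[stream_name].
    batch.target_channels = [["2t", "10u", "10v"]]       # channels the model predicts
    batch.source_channels = [["2t", "10u", "10v", "z"]]  # inputs may carry extra channels
    batch.geoinfo_channels = [[]]                        # still empty, see _extract_coordinates
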
@@ -338,114 +348,131 @@ def items(self) -> typing.Generator[OutputItem, None, None]:
 
     def extract(self, key: ItemKey) -> OutputItem:
         """Extract datasets from lists for one output item."""
-        # adjust shifted values in ItemMeta
-        sample = key.sample - self.sample_start
-        forecast_step = key.forecast_step - self.forecast_offset
+        _logger.debug(f"extracting subset: {key}")
+        offset_key = self._offset_key(key)
         stream_idx = self.streams[key.stream]
-        lens = self.targets_lens[forecast_step][stream_idx]
-
-        # empty target/prediction
-        if len(lens) == 0:
-            start = 0
-            n_samples = 0
-        else:
-            start = sum(lens[:sample])
-            n_samples = lens[sample]
+        datapoints = self._get_datapoints_per_sample(offset_key, stream_idx)
 
-        _logger.debug(f"extracting subset: {key}")
-        _logger.debug(
-            f"sample: start:{self.sample_start} rel_idx:{sample} range:{start}-{start + n_samples}"
-        )
         _logger.debug(
-            f"forecast_step: {key.forecast_step} = {forecast_step} (rel_step) + "
+            f"forecast_step: {key.forecast_step} = {offset_key.forecast_step} (rel_step) + "
             + f"{self.forecast_offset} (forecast_offset)"
         )
         _logger.debug(f"stream: {key.stream} with index: {stream_idx}")
 
-        datapoints = slice(start, start + n_samples)
-
-        if n_samples == 0:
+        if (datapoints.stop - datapoints.start) == 0:
-            target_data = np.zeros((0, len(self.channels[stream_idx])), dtype=np.float32)
-            preds_data = np.zeros((0, len(self.channels[stream_idx])), dtype=np.float32)
+            target_data = np.zeros((0, len(self.target_channels[stream_idx])), dtype=np.float32)
+            preds_data = np.zeros((0, len(self.target_channels[stream_idx])), dtype=np.float32)
         else:
             target_data = (
-                self.targets[forecast_step][stream_idx][0][datapoints].cpu().detach().numpy()
+                self.targets[offset_key.forecast_step][stream_idx][0][datapoints]
+                .cpu()
+                .detach()
+                .numpy()
             )
             preds_data = (
-                self.predictions[forecast_step][stream_idx][0]
+                self.predictions[offset_key.forecast_step][stream_idx][0]
                 .transpose(1, 0)
                 .transpose(1, 2)[datapoints]
                 .cpu()
                 .detach()
                 .numpy()
             )
 
-        _coords = self.targets_coords[forecast_step][stream_idx][datapoints].numpy()
-        coords = _coords[..., :2]  # first two columns are lat,lon
-        geoinfo = _coords[..., 2:]  # the rest is geoinfo => potentially empty
-        if geoinfo.size > 0:  # TODO: set geoinfo to be empty for now
-            geoinfo = np.empty((geoinfo.shape[0], 0))
-            _logger.warning(
-                "geoinformation channels are not implemented yet."
-                + "will be truncated to be of size 0."
-            )
-        times = self.targets_times[forecast_step][stream_idx][
-            datapoints
-        ]  # make conversion to datetime64[ns] here?
-        channels = self.channels[stream_idx]
-        geoinfo_channels = self.geoinfo_channels[stream_idx]
+        data_coords = self._extract_coordinates(stream_idx, offset_key, datapoints)
 
-        assert len(channels) == target_data.shape[1], (
-            "Number of channel names does not align with data"
+        assert len(data_coords.channels) == target_data.shape[1], (
+            "Number of channel names does not align with target data."
         )
-        assert len(channels) == preds_data.shape[1], (
-            "Number of channel names does not align with data"
+        assert len(data_coords.channels) == preds_data.shape[1], (
+            "Number of channel names does not align with prediction data."
         )
 
         if key.with_source:
-            source = self.sources[sample][stream_idx]
-
-            # currently fails since no separate channels for source/target implemented
-            # assert source.data.shape[1] == len(channels), (
-            #     "Number of channel names does not align with data"
-            # )
-
-            source_dataset = OutputDataset(
-                "source",
-                key,
-                source.data,
-                source.datetimes,
-                source.coords,
-                source.geoinfos,
-                channels,
-                geoinfo_channels,
-            )
-
-            _logger.debug(f"source shape: {source_dataset.data.shape}")
+            source_dataset = self._extract_sources(offset_key.sample, stream_idx, key)
         else:
             source_dataset = None
 
         return OutputItem(
             key=key,
             source=source_dataset,
-            target=OutputDataset(
-                "target",
-                key,
-                target_data,
-                times,
-                coords,
-                geoinfo,
-                channels,
-                geoinfo_channels,
-            ),
+            target=OutputDataset("target", key, target_data, **dataclasses.asdict(data_coords)),
             prediction=OutputDataset(
-                "prediction",
-                key,
-                preds_data,
-                times,
-                coords,
-                geoinfo,
-                channels,
-                geoinfo_channels,
+                "prediction", key, preds_data, **dataclasses.asdict(data_coords)
             ),
         )
+
+    def _get_datapoints_per_sample(self, offset_key, stream_idx):
+        lens = self.targets_lens[offset_key.forecast_step][stream_idx]
+
+        # empty target/prediction
+        if len(lens) == 0:
+            start = 0
+            n_samples = 0
+        else:
+            start = sum(lens[: offset_key.sample])
+            n_samples = lens[offset_key.sample]
+
+        _logger.debug(
+            f"sample: start:{self.sample_start} rel_idx:{offset_key.sample} "
+            + f"range:{start}-{start + n_samples}"
+        )
+
+        return slice(start, start + n_samples)
+
+    def _offset_key(self, key: ItemKey):
+        """
+        Correct indices in key to be usable for data extraction.
+
+        `key` carries indices chosen for better output semantics. To be usable in
+        extraction they have to be adjusted to bridge the differences compared to
+        the semantics of the stored data:
+        - `sample` is mapped from a global continuous index to a per-batch index
+        - `forecast_step` is mapped from including `forecast_offset` to indexing
+          the data (always starts at 0)
+        """
+        return ItemKey(
+            key.sample - self.sample_start, key.forecast_step - self.forecast_offset, key.stream
+        )
+
+    def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordinates:
+        _coords = self.targets_coords[offset_key.forecast_step][stream_idx][datapoints].numpy()
+        coords = _coords[:, :2]  # first two columns are lat,lon
+        geoinfo = _coords[:, 2:]  # the rest is geoinfo => potentially empty
+        if geoinfo.size > 0:  # TODO: set geoinfo to be empty for now
+            geoinfo = np.empty((geoinfo.shape[0], 0))
+            _logger.warning(
+                "geoinformation channels are not implemented yet. "
+                + "They will be truncated to size 0."
+            )
+        times = self.targets_times[offset_key.forecast_step][stream_idx][
+            datapoints
+        ]  # make conversion to datetime64[ns] here?
+        channels = self.target_channels[stream_idx]
+        geoinfo_channels = self.geoinfo_channels[stream_idx]
+
+        return DataCoordinates(times, coords, geoinfo, channels, geoinfo_channels)
+
+    def _extract_sources(self, sample, stream_idx, key):
+        channels = self.source_channels[stream_idx]
+        geoinfo_channels = self.geoinfo_channels[stream_idx]
+
+        source = self.sources[sample][stream_idx]
+
+        assert source.data.shape[1] == len(channels), (
+            "Number of source channel names does not align with source data"
+        )
+
+        source_dataset = OutputDataset(
+            "source",
+            key,
+            source.data,
+            source.datetimes,
+            source.coords,
+            source.geoinfos,
+            channels,
+            geoinfo_channels,
+        )
+
+        _logger.debug(f"source shape: {source_dataset.data.shape}")
+
+        return source_dataset
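
Editor's note: the index bookkeeping now lives in two small helpers. _offset_key() translates the caller-facing indices (a global sample index and a forecast step that includes forecast_offset) into batch-relative ones, and _get_datapoints_per_sample() turns the per-sample lengths for that step and stream into a slice. A minimal sketch of the arithmetic with made-up numbers:

    # Toy values, purely illustrative.
    sample_start, forecast_offset = 32, 1      # this batch starts at global sample 32
    key_sample, key_forecast_step = 34, 3      # indices as callers pass them in ItemKey

    rel_sample = key_sample - sample_start            # 2: index within this batch
    rel_step = key_forecast_step - forecast_offset    # 2: index into targets/predictions

    lens = [17, 40, 25, 9]                    # targets_lens[rel_step][stream_idx]
    start = sum(lens[:rel_sample])            # 57
    datapoints = slice(start, start + lens[rel_sample])   # slice(57, 82)
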