Skip to content

Commit 59c0d29

Browse files
authored
Sgrasse/develop/issue 616 (ecmwf#648)
* encapsulate extraction of source data * bundle offsetting of key attributes * consolidate calculation of datapoints indices into method * encapsulate extraction of coordinate axis in function. * replace attribute `channels` by `target_channels` and `source_channels` * ruffed * ruffed * fixes * address Micha's comments * reactivate assert * fix typo / renaming * small fix * uncomment source_n_empty and target_n_empty unused variables
1 parent f17c1c7 commit 59c0d29

File tree

3 files changed

+111
-82
lines changed

3 files changed

+111
-82
lines changed

packages/common/src/weathergen/common/io.py

Lines changed: 105 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,15 @@ def forecast_steps(self) -> list[int]:
285285
return list(example_stream.group_keys())
286286

287287

288+
@dataclasses.dataclass
289+
class DataCoordinates:
290+
times: typing.Any
291+
coords: typing.Any
292+
geoinfo: typing.Any
293+
channels: typing.Any
294+
geoinfo_channels: typing.Any
295+
296+
288297
@dataclasses.dataclass
289298
class OutputBatchData:
290299
"""Provide convenient access to adapt existing output data structures."""
@@ -312,7 +321,8 @@ class OutputBatchData:
312321
streams: dict[str, int]
313322

314323
# stream, channel name
315-
channels: list[list[str]]
324+
target_channels: list[list[str]]
325+
source_channels: list[list[str]]
316326
geoinfo_channels: list[list[str]]
317327

318328
sample_start: int
@@ -338,114 +348,131 @@ def items(self) -> typing.Generator[OutputItem, None, None]:
338348

339349
def extract(self, key: ItemKey) -> OutputItem:
340350
"""Extract datasets from lists for one output item."""
341-
# adjust shifted values in ItemMeta
342-
sample = key.sample - self.sample_start
343-
forecast_step = key.forecast_step - self.forecast_offset
351+
_logger.debug(f"extracting subset: {key}")
352+
offset_key = self._offset_key(key)
344353
stream_idx = self.streams[key.stream]
345-
lens = self.targets_lens[forecast_step][stream_idx]
346-
347-
# empty target/prediction
348-
if len(lens) == 0:
349-
start = 0
350-
n_samples = 0
351-
else:
352-
start = sum(lens[:sample])
353-
n_samples = lens[sample]
354+
datapoints = self._get_datapoints_per_sample(offset_key, stream_idx)
354355

355-
_logger.debug(f"extracting subset: {key}")
356-
_logger.debug(
357-
f"sample: start:{self.sample_start} rel_idx:{sample} range:{start}-{start + n_samples}"
358-
)
359356
_logger.debug(
360-
f"forecast_step: {key.forecast_step} = {forecast_step} (rel_step) + "
357+
f"forecast_step: {key.forecast_step} = {offset_key.forecast_step} (rel_step) + "
361358
+ f"{self.forecast_offset} (forecast_offset)"
362359
)
363360
_logger.debug(f"stream: {key.stream} with index: {stream_idx}")
364361

365-
datapoints = slice(start, start + n_samples)
366-
367-
if n_samples == 0:
362+
if (datapoints.stop - datapoints.start) == 0:
368363
target_data = np.zeros((0, len(self.channels[stream_idx])), dtype=np.float32)
369364
preds_data = np.zeros((0, len(self.channels[stream_idx])), dtype=np.float32)
370365
else:
371366
target_data = (
372-
self.targets[forecast_step][stream_idx][0][datapoints].cpu().detach().numpy()
367+
self.targets[offset_key.forecast_step][stream_idx][0][datapoints]
368+
.cpu()
369+
.detach()
370+
.numpy()
373371
)
374372
preds_data = (
375-
self.predictions[forecast_step][stream_idx][0]
373+
self.predictions[offset_key.forecast_step][stream_idx][0]
376374
.transpose(1, 0)
377375
.transpose(1, 2)[datapoints]
378376
.cpu()
379377
.detach()
380378
.numpy()
381379
)
382380

383-
_coords = self.targets_coords[forecast_step][stream_idx][datapoints].numpy()
384-
coords = _coords[..., :2] # first two columns are lat,lon
385-
geoinfo = _coords[..., 2:] # the rest is geoinfo => potentially empty
386-
if geoinfo.size > 0: # TODO: set geoinfo to be empty for now
387-
geoinfo = np.empty((geoinfo.shape[0], 0))
388-
_logger.warning(
389-
"geoinformation channels are not implemented yet."
390-
+ "will be truncated to be of size 0."
391-
)
392-
times = self.targets_times[forecast_step][stream_idx][
393-
datapoints
394-
] # make conversion to datetime64[ns] here?
395-
channels = self.channels[stream_idx]
396-
geoinfo_channels = self.geoinfo_channels[stream_idx]
381+
data_coords = self._extract_coordinates(stream_idx, offset_key, datapoints)
397382

398-
assert len(channels) == target_data.shape[1], (
399-
"Number of channel names does not align with data"
383+
assert len(data_coords.channels) == target_data.shape[1], (
384+
"Number of channel names does not align with target data."
400385
)
401-
assert len(channels) == preds_data.shape[1], (
402-
"Number of channel names does not align with data"
386+
assert len(data_coords.channels) == preds_data.shape[1], (
387+
"Number of channel names does not align with prediction data."
403388
)
404389

405390
if key.with_source:
406-
source = self.sources[sample][stream_idx]
407-
408-
# currently fails since no separate channels for source/target implemented
409-
# assert source.data.shape[1] == len(channels), (
410-
# "Number of channel names does not align with data"
411-
# )
412-
413-
source_dataset = OutputDataset(
414-
"source",
415-
key,
416-
source.data,
417-
source.datetimes,
418-
source.coords,
419-
source.geoinfos,
420-
channels,
421-
geoinfo_channels,
422-
)
423-
424-
_logger.debug(f"source shape: {source_dataset.data.shape}")
391+
source_dataset = self._extract_sources(offset_key.sample, stream_idx, key)
425392
else:
426393
source_dataset = None
427394

428395
return OutputItem(
429396
key=key,
430397
source=source_dataset,
431-
target=OutputDataset(
432-
"target",
433-
key,
434-
target_data,
435-
times,
436-
coords,
437-
geoinfo,
438-
channels,
439-
geoinfo_channels,
440-
),
398+
target=OutputDataset("target", key, target_data, **dataclasses.asdict(data_coords)),
441399
prediction=OutputDataset(
442-
"prediction",
443-
key,
444-
preds_data,
445-
times,
446-
coords,
447-
geoinfo,
448-
channels,
449-
geoinfo_channels,
400+
"prediction", key, preds_data, **dataclasses.asdict(data_coords)
450401
),
451402
)
403+
404+
def _get_datapoints_per_sample(self, offset_key, stream_idx):
405+
lens = self.targets_lens[offset_key.forecast_step][stream_idx]
406+
407+
# empty target/prediction
408+
if len(lens) == 0:
409+
start = 0
410+
n_samples = 0
411+
else:
412+
start = sum(lens[: offset_key.sample])
413+
n_samples = lens[offset_key.sample]
414+
415+
_logger.debug(
416+
f"sample: start:{self.sample_start} rel_idx:{offset_key.sample}"
417+
+ f"range:{start}-{start + n_samples}"
418+
)
419+
420+
return slice(start, start + n_samples)
421+
422+
def _offset_key(self, key: ItemKey):
423+
"""
424+
Correct indices in key to be usable for data extraction.
425+
426+
`key` contains indices that are adjusted to have better output semantics.
427+
To be usable in extraction these have to be adjusted to bridge the differences
428+
compared to the semantics of the data.
429+
- `sample` is adjusted from a global continuous index to a per-batch index
430+
- `forecast_step` is adjusted from including `forecast_offset` to indexing
431+
the data (always starts at 0)
432+
"""
433+
return ItemKey(
434+
key.sample - self.sample_start, key.forecast_step - self.forecast_offset, key.stream
435+
)
436+
437+
def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordinates:
438+
_coords = self.targets_coords[offset_key.forecast_step][stream_idx][datapoints].numpy()
439+
coords = _coords[:, :2] # first two columns are lat,lon
440+
geoinfo = _coords[:, 2:] # the rest is geoinfo => potentially empty
441+
if geoinfo.size > 0: # TODO: set geoinfo to be empty for now
442+
geoinfo = np.empty((geoinfo.shape[0], 0))
443+
_logger.warning(
444+
"geoinformation channels are not implemented yet."
445+
+ "will be truncated to be of size 0."
446+
)
447+
times = self.targets_times[offset_key.forecast_step][stream_idx][
448+
datapoints
449+
] # make conversion to datetime64[ns] here?
450+
channels = self.target_channels[stream_idx]
451+
geoinfo_channels = self.geoinfo_channels[stream_idx]
452+
453+
return DataCoordinates(times, coords, geoinfo, channels, geoinfo_channels)
454+
455+
def _extract_sources(self, sample, stream_idx, key):
456+
channels = self.source_channels[stream_idx]
457+
geoinfo_channels = self.geoinfo_channels[stream_idx]
458+
459+
source = self.sources[sample][stream_idx]
460+
461+
assert source.data.shape[1] == len(channels), (
462+
"Number of source channel names does not align with source data"
463+
)
464+
465+
source_dataset = OutputDataset(
466+
"source",
467+
key,
468+
source.data,
469+
source.datetimes,
470+
source.coords,
471+
source.geoinfos,
472+
channels,
473+
geoinfo_channels,
474+
)
475+
476+
_logger.debug(f"source shape: {source_dataset.data.shape}")
477+
478+
return source_dataset

src/weathergen/datasets/data_reader_obs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def __init__(self, tw_handler: TimeWindowHandler, filename: Path, stream_info: d
5252
t_chs = stream_info.get("target")
5353
t_chs_exclude = stream_info.get("target_exclude", [])
5454

55-
source_n_empty = len(s_chs) > 0 if s_chs is not None else True
55+
# source_n_empty = len(s_chs) > 0 if s_chs is not None else True
5656
# assert source_n_empty, "source is empty; at least one channels must be present."
57-
target_n_empty = len(t_chs) > 0 if t_chs is not None else True
57+
# target_n_empty = len(t_chs) > 0 if t_chs is not None else True
5858
# assert target_n_empty, "target is empty; at least one channels must be present."
5959

6060
self.source_channels = self.select_channels(data_colnames, s_chs, s_chs_exclude)

src/weathergen/utils/validation_io.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def write_output(
3535

3636
_logger.debug(f"Using output streams: {output_streams} from streams: {stream_names}")
3737

38-
channels: list[list[str]] = [list(stream.val_target_channels) for stream in cf.streams]
38+
target_channels: list[list[str]] = [list(stream.val_target_channels) for stream in cf.streams]
39+
source_channels: list[list[str]] = [list(stream.val_source_channels) for stream in cf.streams]
3940

4041
geoinfo_channels = [[] for _ in cf.streams] # TODO obtain channels
4142

@@ -55,7 +56,8 @@ def write_output(
5556
targets_times_all,
5657
targets_lens,
5758
output_streams,
58-
channels,
59+
target_channels,
60+
source_channels,
5961
geoinfo_channels,
6062
sample_start,
6163
cf.forecast_offset,

0 commit comments

Comments
 (0)