Skip to content

Commit 511f036

Browse files
authored
fix: associate output stream names with correct index (ecmwf#519)
* fix: associate output stream names with correct index * ruffed * fix: iteration over output items * address comments * fix: correctly index channels * fix stream indexing logic, add asserts * fix: extraction of data/coordinates for sources * fix assert
1 parent bb7e269 commit 511f036

File tree

2 files changed

+49
-25
lines changed

2 files changed

+49
-25
lines changed

packages/common/src/weathergen/common/io.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,8 @@ def forecast_steps(self) -> list[int]:
240240
class OutputBatchData:
241241
"""Provide convenient access to adapt existing output data structures."""
242242

243-
# sample, stream, tensor(datapoint, channel) => datapoints is accross all datasets per stream
243+
# sample, stream, tensor(datapoint, channel+coords)
244+
# => datapoints is accross all datasets per stream
244245
sources: list[list]
245246

246247
# fstep, stream, redundant dim (size 1), tensor(sample x datapoint, channel)
@@ -258,7 +259,8 @@ class OutputBatchData:
258259
# fstep, stream, redundant dim (size 1)
259260
targets_lens: list[list[list[int]]]
260261

261-
stream_names: list[str]
262+
# stream name: index into data (only streams in analysis_streams_output)
263+
streams: dict[str, int]
262264

263265
# stream, channel name
264266
channels: list[list[str]]
@@ -279,22 +281,23 @@ def forecast_steps(self):
279281

280282
def items(self) -> typing.Generator[OutputItem, None, None]:
281283
"""Iterate over possible output items"""
282-
filtered_streams = (stream for stream in self.stream_names if stream != "")
283284
# TODO: filter for empty items?
284-
for s, fo_s, fi_s in itertools.product(self.samples, self.forecast_steps, filtered_streams):
285+
for s, fo_s, fi_s in itertools.product(
286+
self.samples, self.forecast_steps, self.streams.keys()
287+
):
285288
yield self.extract(ItemKey(int(s), int(fo_s), fi_s))
286289

287290
def extract(self, key: ItemKey) -> OutputItem:
288291
"""Extract datasets from lists for one output item."""
289292
# adjust shifted values in ItemMeta
290293
sample = key.sample - self.sample_start
291294
forecast_step = key.forecast_step - self.forecast_offset
292-
stream_idx = self.stream_names.index(key.stream) # TODO: assure this is correct
295+
stream_idx = self.streams[key.stream]
293296
lens = self.targets_lens[forecast_step][stream_idx]
294297
start = sum(lens[:sample])
295298
n_samples = lens[sample]
296299

297-
_logger.info("extracting subset")
300+
_logger.info(f"extracting subset: {key}")
298301
_logger.info(
299302
f"sample: start:{self.sample_start} rel_idx:{sample} range:{start}-{start + n_samples}"
300303
)
@@ -331,18 +334,40 @@ def extract(self, key: ItemKey) -> OutputItem:
331334
channels = self.channels[stream_idx]
332335
geoinfo_channels = self.geoinfo_channels[stream_idx]
333336

337+
assert len(channels) == target_data.shape[1], (
338+
"Number of channel names does not align with data"
339+
)
340+
assert len(channels) == preds_data.shape[1], (
341+
"Number of channel names does not align with data"
342+
)
343+
334344
if key.with_source:
335345
source_data = self.sources[sample][stream_idx].cpu().detach().numpy()
346+
347+
# split data into coords, geoinfo, channels
348+
_source_coords = source_data[:, : -len(channels)]
349+
source_coords = _source_coords[:, :2]
350+
source_times = _source_coords[:, 2]
351+
source_geoinfo = _source_coords[:, 2 : -len(channels)]
352+
353+
# TODO asserts that times, coords, geoinfos should match?
354+
336355
source_dataset = OutputDataset(
337356
"source",
338357
key,
339-
source_data,
340-
times,
341-
coords,
342-
geoinfo,
358+
source_data[:, -len(channels) :],
359+
source_times,
360+
source_coords,
361+
source_geoinfo,
343362
channels,
344363
geoinfo_channels,
345364
)
365+
366+
_logger.info(f"source shape: {source_dataset.data.shape}")
367+
assert len(channels) == source_dataset.data.shape[1], (
368+
"Number of channel names does not align with data"
369+
)
370+
assert len(geoinfo_channels) == source_dataset.geoinfo.shape[1]
346371
else:
347372
source_dataset = None
348373

src/weathergen/utils/validation_io.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,36 +26,35 @@ def write_output(
2626
targets_times_all,
2727
targets_lens,
2828
):
29-
if cf.analysis_streams_output is None:
30-
output_stream_names = [stream.name for stream in cf.streams]
31-
_logger.info(f"Using all streams as output streams: {output_stream_names}")
32-
else:
33-
output_stream_names = [
34-
stream.name for stream in cf.streams if stream.name in cf.analysis_streams_output
35-
]
36-
_logger.info(f"Using output streams: {output_stream_names}")
37-
# TODO: streams anemoi `source`, `target` commented out???
29+
stream_names = [stream.name for stream in cf.streams]
30+
output_stream_names = cf.analysis_streams_output
31+
if output_stream_names is None:
32+
output_stream_names = stream_names
3833

39-
channels: list[list[str]] = [
40-
list(stream.val_target_channels)
41-
for stream in cf.streams
42-
if stream.name in output_stream_names
43-
]
34+
output_streams = {name: stream_names.index(name) for name in output_stream_names}
35+
36+
_logger.info(f"Using output streams: {output_streams} from streams: {stream_names}")
37+
38+
channels: list[list[str]] = [list(stream.val_target_channels) for stream in cf.streams]
4439

4540
geoinfo_channels = [[] for _ in cf.streams] # TODO obtain channels
4641

4742
# assume: is batch size guarnteed and constant:
4843
# => calculate global sample indices for this batch by offsetting by sample_start
4944
sample_start = batch_idx * cf.batch_size_validation_per_gpu
5045

46+
assert len(stream_names) == len(targets_all[0]), "data does not match number of streams"
47+
assert len(stream_names) == len(preds_all[0]), "data does not match number of streams"
48+
assert len(stream_names) == len(sources[0]), "data does not match number of streams"
49+
5150
data = io.OutputBatchData(
5251
sources,
5352
targets_all,
5453
preds_all,
5554
targets_coords_all,
5655
targets_times_all,
5756
targets_lens,
58-
output_stream_names,
57+
output_streams,
5958
channels,
6059
geoinfo_channels,
6160
sample_start,

0 commit comments

Comments
 (0)