Skip to content

Commit 13321c7

Browse files
authored
Fixing bugs in forecasting (ecmwf#137)
- Data loader didn't load consecutive time steps
- Problem in internal roll-out with model parameter passing

Closes ecmwf#134
1 parent 8dba541 commit 13321c7

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

src/weathergen/datasets/multi_stream_data_sampler.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(self, cf, start_date, end_date, batch_size, samples_per_epoch, shuf
4646
self.forecast_delta_hrs = (
4747
cf.forecast_delta_hrs if cf.forecast_delta_hrs > 0 else self.len_hrs
4848
)
49+
assert self.forecast_delta_hrs == self.len_hrs, "Only supported option at the moment"
4950
self.forecast_steps = np.array(
5051
[cf.forecast_steps] if type(cf.forecast_steps) == int else cf.forecast_steps
5152
)
@@ -263,15 +264,10 @@ def __iter__(self):
263264
idx = self.perms[idx_raw % self.perms.shape[0]]
264265
idx_raw += 1
265266

266-
step_dt = self.len_hrs // self.step_hrs
267-
step_forecast_dt = (
268-
step_dt + (self.forecast_delta_hrs * forecast_dt) // self.step_hrs
269-
)
270-
271267
# TODO: this has to be independent of specific datasets
272-
time_win1, time_win2 = (
268+
time_win1, _ = (
273269
self.streams_datasets[-1][0].time_window(idx),
274-
self.streams_datasets[-1][0].time_window(idx + step_forecast_dt),
270+
self.streams_datasets[-1][0].time_window(idx + self.len_hrs // self.step_hrs),
275271
)
276272

277273
streams_data = []
@@ -288,7 +284,7 @@ def __iter__(self):
288284
(coords, geoinfos, source, times) = ds.get_source(idx)
289285
for it in range(1, self.input_window_steps):
290286
(coords0, geoinfos0, source0, times0) = ds.get_source(
291-
idx - it * step_dt
287+
idx - it * self.len_hrs
292288
)
293289
coords = np.concatenate([coords0, coords], 0)
294290
geoinfos = np.concatenate([geoinfos0, geoinfos], 0)
@@ -325,9 +321,15 @@ def __iter__(self):
325321
for fstep in range(forecast_dt + 1):
326322
# collect all sources
327323
for _, ds in enumerate(stream_ds):
328-
(coords, geoinfos, target, times) = ds.get_target(
329-
idx + step_forecast_dt
324+
step_forecast_dt = (
325+
idx + (self.forecast_delta_hrs * fstep) // self.step_hrs
330326
)
327+
_, time_win2 = (
328+
self.streams_datasets[-1][0].time_window(idx),
329+
self.streams_datasets[-1][0].time_window(step_forecast_dt),
330+
)
331+
332+
(coords, geoinfos, target, times) = ds.get_target(step_forecast_dt)
331333

332334
if target.shape[0] == 0:
333335
stream_data.add_empty_target(fstep)

src/weathergen/model/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,7 @@ def forward(self, model_params, batch, forecast_steps):
585585
# prediction
586586
preds_all += [
587587
self.predict(
588+
model_params,
588589
forecast_steps,
589590
tokens,
590591
streams_data,

0 commit comments

Comments
 (0)