@@ -46,6 +46,7 @@ def __init__(self, cf, start_date, end_date, batch_size, samples_per_epoch, shuf
4646 self .forecast_delta_hrs = (
4747 cf .forecast_delta_hrs if cf .forecast_delta_hrs > 0 else self .len_hrs
4848 )
49+ assert self .forecast_delta_hrs == self .len_hrs , "Only supported option at the moment"
4950 self .forecast_steps = np .array (
5051 [cf .forecast_steps ] if type (cf .forecast_steps ) == int else cf .forecast_steps
5152 )
@@ -263,15 +264,10 @@ def __iter__(self):
263264 idx = self .perms [idx_raw % self .perms .shape [0 ]]
264265 idx_raw += 1
265266
266- step_dt = self .len_hrs // self .step_hrs
267- step_forecast_dt = (
268- step_dt + (self .forecast_delta_hrs * forecast_dt ) // self .step_hrs
269- )
270-
271267 # TODO: this has to be independent of specific datasets
272- time_win1 , time_win2 = (
268+ time_win1 , _ = (
273269 self .streams_datasets [- 1 ][0 ].time_window (idx ),
274- self .streams_datasets [- 1 ][0 ].time_window (idx + step_forecast_dt ),
270+ self .streams_datasets [- 1 ][0 ].time_window (idx + self . len_hrs // self . step_hrs ),
275271 )
276272
277273 streams_data = []
@@ -288,7 +284,7 @@ def __iter__(self):
288284 (coords , geoinfos , source , times ) = ds .get_source (idx )
289285 for it in range (1 , self .input_window_steps ):
290286 (coords0 , geoinfos0 , source0 , times0 ) = ds .get_source (
291- idx - it * step_dt
287+ idx - it * self . len_hrs
292288 )
293289 coords = np .concatenate ([coords0 , coords ], 0 )
294290 geoinfos = np .concatenate ([geoinfos0 , geoinfos ], 0 )
@@ -325,9 +321,15 @@ def __iter__(self):
325321 for fstep in range (forecast_dt + 1 ):
326322 # collect all sources
327323 for _ , ds in enumerate (stream_ds ):
328- ( coords , geoinfos , target , times ) = ds . get_target (
329- idx + step_forecast_dt
324+ step_forecast_dt = (
325+ idx + ( self . forecast_delta_hrs * fstep ) // self . step_hrs
330326 )
327+ _ , time_win2 = (
328+ self .streams_datasets [- 1 ][0 ].time_window (idx ),
329+ self .streams_datasets [- 1 ][0 ].time_window (step_forecast_dt ),
330+ )
331+
332+ (coords , geoinfos , target , times ) = ds .get_target (step_forecast_dt )
331333
332334 if target .shape [0 ] == 0 :
333335 stream_data .add_empty_target (fstep )
0 commit comments