77# granted to it by virtue of its status as an intergovernmental organisation
88# nor does it submit to any jurisdiction.
99
10- import numpy as np
11- import torch
12- import math
1310import datetime
14- from copy import deepcopy
1511import logging
16- import time
17- import code
18- import os
19- import yaml
2012
13+ import numpy as np
2114import pandas as pd
15+ import torch
2216
23- from weathergen .datasets .obs_dataset import ObsDataset
2417from weathergen .datasets .anemoi_dataset import AnemoiDataset
25- from weathergen .datasets .normalizer import DataNormalizer
2618from weathergen .datasets .batchifyer import Batchifyer
19+ from weathergen .datasets .normalizer import DataNormalizer
20+ from weathergen .datasets .obs_dataset import ObsDataset
2721from weathergen .datasets .utils import merge_cells
28-
2922from weathergen .utils .logger import logger
3023
3124
@@ -69,7 +62,7 @@ def __init__(
6962 self .len_hrs = len_hrs
7063 self .step_hrs = step_hrs
7164
72- fc_policy_seq = "sequential" == forecast_policy or "sequential_random" == forecast_policy
65+ fc_policy_seq = forecast_policy == "sequential" or forecast_policy == "sequential_random"
7366 assert forecast_steps >= 0 if not fc_policy_seq else True
7467 self .forecast_delta_hrs = forecast_delta_hrs if forecast_delta_hrs > 0 else self .len_hrs
7568 self .forecast_steps = np .array (
@@ -111,7 +104,7 @@ def __init__(
111104 # the processing here is not natural but a workaround to various inconsistencies in the
112105 # current datasets
113106 data_idxs = [
114- i for i , cn in enumerate (ds .selected_colnames [do :]) if "obsvalue_" == cn [:9 ]
107+ i for i , cn in enumerate (ds .selected_colnames [do :]) if cn [:9 ] == "obsvalue_"
115108 ]
116109 mask = np .ones (len (ds .selected_colnames [do :]), dtype = np .int32 ).astype (bool )
117110 mask [data_idxs ] = False
@@ -272,7 +265,7 @@ def __iter__(self):
272265 # idx_raw is used to index into the dataset; the decoupling is needed
273266 # since there are empty batches
274267 idx_raw = iter_start
275- for i , bidx in enumerate (range (iter_start , iter_end , self .batch_size )):
268+ for i , _bidx in enumerate (range (iter_start , iter_end , self .batch_size )):
276269 # targets, targets_coords, targets_idxs = [], [], [],
277270 tcs , tcs_lens , target_tokens , source_tokens_cells , source_tokens_lens = (
278271 [],
@@ -314,7 +307,7 @@ def __iter__(self):
314307 c_source_raw = []
315308
316309 for obs_id , (stream_info , stream_dsn , stream_idxs ) in enumerate (
317- zip (self .streams , self .obs_datasets_norm , self .obs_datasets_idxs )
310+ zip (self .streams , self .obs_datasets_norm , self .obs_datasets_idxs , strict = False )
318311 ):
319312 s_tcs = []
320313 s_tcs_lens = []
@@ -326,17 +319,17 @@ def __iter__(self):
326319 s_source_raw = []
327320
328321 token_size = stream_info ["token_size" ]
329- grid = (
330- stream_info ["gridded_output" ] if "gridded_output" in stream_info else None
331- )
332- grid_info = (
333- stream_info ["gridded_output_info" ]
334- if "gridded_output_info" in stream_info
335- else None
336- )
322+ # grid = (
323+ # stream_info["gridded_output"] if "gridded_output" in stream_info else None
324+ # )
325+ # grid_info = (
326+ # stream_info["gridded_output_info"]
327+ # if "gridded_output_info" in stream_info
328+ # else None
329+ # )
337330
338331 for i_source , ((ds , normalizer , do ), s_idxs ) in enumerate (
339- zip (stream_dsn , stream_idxs )
332+ zip (stream_dsn , stream_idxs , strict = False )
340333 ):
341334 # source window (of potentially multi-step length)
342335 (source1 , times1 ) = ds [idx ]
@@ -417,7 +410,7 @@ def __iter__(self):
417410 for fstep in range (forecast_dt + 1 ):
418411 # collect all streams
419412 for i_source , ((ds , normalizer , do ), s_idxs ) in enumerate (
420- zip (stream_dsn , stream_idxs )
413+ zip (stream_dsn , stream_idxs , strict = False )
421414 ):
422415 (source2 , times2 ) = ds [idx + step_forecast_dt ]
423416
@@ -534,15 +527,17 @@ def __iter__(self):
534527 idxs = torch .cat (
535528 [
536529 torch .arange (o , o + l , dtype = torch .int64 )
537- for o , l in zip (offsets , source_tokens_lens [ib , itype ])
530+ for o , l in zip (offsets , source_tokens_lens [ib , itype ], strict = False )
538531 ]
539532 )
540533 idxs_embed [- 1 ] += [idxs .unsqueeze (1 )]
541534 idxs_embed_pe [- 1 ] += [
542535 torch .cat (
543536 [
544537 torch .arange (o , o + l , dtype = torch .int32 )
545- for o , l in zip (offsets_pe , source_tokens_lens [ib ][itype ])
538+ for o , l in zip (
539+ offsets_pe , source_tokens_lens [ib ][itype ], strict = False
540+ )
546541 ]
547542 )
548543 ]
0 commit comments