88# nor does it submit to any jurisdiction.
99
1010import datetime
11+ import logging
1112
1213import numpy as np
1314from anemoi .datasets import open_dataset
1415
16+ _logger = logging .getLogger (__name__ )
17+
1518
1619class AnemoiDataset :
1720 "Wrapper for Anemoi dataset"
@@ -30,26 +33,26 @@ def __init__(
3033 assert len_hrs == step_hrs , "Currently only step_hrs=len_hrs is supported"
3134
3235 # open dataset to peak that it is compatible with requested parameters
33- self . ds = open_dataset (filename )
36+ ds = open_dataset (filename )
3437
3538 # check that start and end time are within the dataset time range
3639
37- ds_dt_start = self . ds .dates [0 ]
38- ds_dt_end = self . ds .dates [- 1 ]
40+ ds_dt_start = ds .dates [0 ]
41+ ds_dt_end = ds .dates [- 1 ]
3942
4043 format_str = "%Y%m%d%H%M%S"
4144 dt_start = datetime .datetime .strptime (str (start ), format_str )
4245 dt_end = datetime .datetime .strptime (str (end ), format_str )
4346
4447 # TODO, TODO, TODO: we need proper alignment for the case where self.ds.frequency
4548 # is not a multile of len_hrs
46- self .num_steps_per_window = int ((len_hrs * 3600 ) / self . ds .frequency .seconds )
49+ self .num_steps_per_window = int ((len_hrs * 3600 ) / ds .frequency .seconds )
4750
4851 # open dataset
4952
5053 # caches lats and lons
51- self .latitudes = self . ds .latitudes .astype (np .float32 )
52- self .longitudes = self . ds .longitudes .astype (np .float32 )
54+ self .latitudes = ds .latitudes .astype (np .float32 )
55+ self .longitudes = ds .longitudes .astype (np .float32 )
5356
5457 # TODO: define in base class
5558 self .geoinfo_idx = []
@@ -59,8 +62,8 @@ def __init__(
5962 source_channels = stream_info ["source" ] if "source" in stream_info else None
6063 self .source_idx = np .sort (
6164 [
62- self . ds .name_to_index [k ]
63- for i , (k , v ) in enumerate (self . ds .typed_variables .items ())
65+ ds .name_to_index [k ]
66+ for i , (k , v ) in enumerate (ds .typed_variables .items ())
6467 if (
6568 not v .is_computed_forcing
6669 and not v .is_constant_in_time
@@ -75,8 +78,8 @@ def __init__(
7578 target_channels = stream_info ["target" ] if "target" in stream_info else None
7679 self .target_idx = np .sort (
7780 [
78- self . ds .name_to_index [k ]
79- for i , (k , v ) in enumerate ( self . ds .typed_variables .items () )
81+ ds .name_to_index [k ]
82+ for (k , v ) in ds .typed_variables .items ()
8083 if (
8184 not v .is_computed_forcing
8285 and not v .is_constant_in_time
@@ -88,21 +91,20 @@ def __init__(
8891 )
8992 ]
9093 )
91- self .source_channels = [self . ds .variables [i ] for i in self .source_idx ]
92- self .target_channels = [self . ds .variables [i ] for i in self .target_idx ]
94+ self .source_channels = [ds .variables [i ] for i in self .source_idx ]
95+ self .target_channels = [ds .variables [i ] for i in self .target_idx ]
9396
9497 self .properties = {
9598 "stream_id" : 0 ,
9699 }
97- self .mean = self . ds .statistics ["mean" ]
98- self .stdev = self . ds .statistics ["stdev" ]
100+ self .mean = ds .statistics ["mean" ]
101+ self .stdev = ds .statistics ["stdev" ]
99102
100103 # set dataset to None when no overlap with time range
101104 if dt_start >= ds_dt_end or dt_end <= ds_dt_start :
102105 self .ds = None
103- return
104-
105- self .ds = open_dataset (self .ds , frequency = str (step_hrs ) + "h" , start = dt_start , end = dt_end )
106+ else :
107+ self .ds = open_dataset (ds , frequency = str (step_hrs ) + "h" , start = dt_start , end = dt_end )
106108
107109 def __len__ (self ):
108110 "Length of dataset"
@@ -140,8 +142,10 @@ def _get(
140142 )
141143
142144 # extract number of time steps and collapse ensemble dimension
145+
143146 data = self .ds [idx : idx + self .num_steps_per_window ][:, :, 0 ]
144- # extract channels
147+
148+ # # extract channels
145149 data = (
146150 data [:, channels_idx ].transpose ([0 , 2 , 1 ]).reshape ((data .shape [0 ] * data .shape [2 ], - 1 ))
147151 )
0 commit comments