1111import logging
1212
1313import numpy as np
14+ import torch
1415from anemoi .datasets import open_dataset
1516
1617_logger = logging .getLogger (__name__ )
@@ -28,6 +29,29 @@ def __init__(
2829 filename : str ,
2930 stream_info : dict ,
3031 ) -> None :
32+ """
33+ Construct dataset based on anemoi dataset
34+
35+ Parameters
36+ ----------
37+ start : int
38+ Start time
39+ end : int
40+ End time
41+ len_hrs : int
42+ length of data window
43+ step_hrs :
44+ delta hours between start times of windows
45+ filename :
46+ filename (and path) of dataset
47+ stream_info :
48+ information about stream
49+
50+ Returns
51+ -------
52+ None
53+ """
54+
3155 # TODO: add support for different normalization modes
3256
3357 assert len_hrs == step_hrs , "Currently only step_hrs=len_hrs is supported"
@@ -106,31 +130,69 @@ def __init__(
106130 else :
107131 self .ds = open_dataset (ds , frequency = str (step_hrs ) + "h" , start = dt_start , end = dt_end )
108132
def __len__(self) -> int:
    """
    Number of entries in the dataset

    Parameters
    ----------
    None

    Returns
    -------
    len(self.ds), or 0 when self.ds is falsy (e.g. no dataset is available)
    """
    # a falsy dataset handle counts as an empty dataset
    return len(self.ds) if self.ds else 0
116149
def get_source(self, idx: int) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Get source data for idx

    Thin wrapper around self._get that selects the source channels (self.source_idx).

    Parameters
    ----------
    idx : int
        Index of temporal window

    Returns
    -------
    source data (coords, geoinfos, data, datetimes)
    """
    # np.ndarray is the array type; np.array is only the factory function
    return self._get(idx, self.source_idx)
122164
def get_target(self, idx: int) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Get target data for idx

    Thin wrapper around self._get that selects the target channels (self.target_idx).

    Parameters
    ----------
    idx : int
        Index of temporal window

    Returns
    -------
    target data (coords, geoinfos, data, datetimes)
    """
    # np.ndarray is the array type; np.array is only the factory function
    return self._get(idx, self.target_idx)
128179
129180 def _get (
130181 self , idx : int , channels_idx : np .array
131182 ) -> tuple [np .array , np .array , np .array , np .array ]:
132183 """
133- TODO
184+ Get data for window
185+
186+ Parameters
187+ ----------
188+ idx : int
189+ Index of temporal window
190+ channels_idx : np.array
191+ Selection of channels
192+
193+ Returns
194+ -------
195+ data (coords, geoinfos, data, datetimes)
134196 """
135197
136198 if not self .ds :
@@ -172,74 +234,186 @@ def _get(
172234
173235 return (latlon , geoinfos , data , datetimes )
174236
def get_source_num_channels(self) -> int:
    """
    Number of channels selected as source

    Parameters
    ----------
    None

    Returns
    -------
    number of source channels, i.e. the number of entries in self.source_idx
    """
    num_channels = len(self.source_idx)
    return num_channels
186250
def get_target_num_channels(self) -> int:
    """
    Number of channels selected as target

    Parameters
    ----------
    None

    Returns
    -------
    number of target channels, i.e. the number of entries in self.target_idx
    """
    num_channels = len(self.target_idx)
    return num_channels
192264
def get_coords_size(self) -> int:
    """
    Number of coordinate components per data point

    Parameters
    ----------
    None

    Returns
    -------
    size of coords (fixed at 2)
    """
    coords_size = 2
    return coords_size
198278
def get_geoinfo_size(self) -> int:
    """
    Number of geoinfo channels

    Parameters
    ----------
    None

    Returns
    -------
    size of geoinfos, i.e. the number of entries in self.geoinfo_idx
    """
    geoinfo_size = len(self.geoinfo_idx)
    return geoinfo_size
204292
def normalize_coords(self, coords: torch.Tensor) -> torch.Tensor:
    """
    Normalize coordinates

    Component 0 is mapped to sin(deg2rad(x)) and component 1 to sin(0.5 * deg2rad(x)).
    The input is modified in place and returned.

    Parameters
    ----------
    coords :
        coordinates to be normalized, in degrees, with the two components along the last
        axis; torch tensors and numpy arrays are both accepted

    Returns
    -------
    Normalized coordinates (the same object as the input)
    """
    # Use torch ops for tensors: numpy ufuncs need a host copy and therefore fail for
    # tensors on an accelerator or tensors that track gradients.
    if isinstance(coords, torch.Tensor):
        sin, deg2rad = torch.sin, torch.deg2rad
    else:
        sin, deg2rad = np.sin, np.deg2rad

    coords[..., 0] = sin(deg2rad(coords[..., 0]))
    # half angle for the second component
    coords[..., 1] = sin(0.5 * deg2rad(coords[..., 1]))

    return coords
213310
def normalize_geoinfos(self, geoinfos: torch.Tensor) -> torch.Tensor:
    """
    Normalize geoinfos

    No normalization is applied: the input is only checked to carry zero geoinfo
    channels and is then returned unchanged.

    Parameters
    ----------
    geoinfos :
        geoinfos to be normalized; last axis must have size 0

    Returns
    -------
    Normalized geoinfo (the same object as the input)

    Raises
    ------
    AssertionError
        if the last axis of geoinfos is not empty
    """

    assert geoinfos.shape[-1] == 0, "incorrect number of geoinfo channels"
    return geoinfos
221327
def normalize_source_channels(self, source: torch.Tensor) -> torch.Tensor:
    """
    Normalize source channels

    Channel i along the last axis is mapped to (x - mean) / stdev, using the statistics
    of dataset channel self.source_idx[i]. The input is modified in place and returned.

    Parameters
    ----------
    source :
        source data to be normalized; last axis must have len(self.source_idx) entries

    Returns
    -------
    Normalized data (the same object as the input)

    Raises
    ------
    AssertionError
        if the size of the last axis does not match the number of source channels
    """
    assert source.shape[-1] == len(self.source_idx), "incorrect number of channels"
    # i indexes the channel within source, ch the matching channel in the statistics
    for i, ch in enumerate(self.source_idx):
        source[..., i] = (source[..., i] - self.mean[ch]) / self.stdev[ch]

    return source
231346
def normalize_target_channels(self, target: torch.Tensor) -> torch.Tensor:
    """
    Normalize target channels

    Channel i along the last axis is mapped to (x - mean) / stdev, using the statistics
    of dataset channel self.target_idx[i]. The input is modified in place and returned.

    Parameters
    ----------
    target :
        target data to be normalized; last axis must have len(self.target_idx) entries

    Returns
    -------
    Normalized data (the same object as the input)

    Raises
    ------
    AssertionError
        if the size of the last axis does not match the number of target channels
    """
    assert target.shape[-1] == len(self.target_idx), "incorrect number of channels"
    # i indexes the channel within target, ch the matching channel in the statistics
    for i, ch in enumerate(self.target_idx):
        target[..., i] = (target[..., i] - self.mean[ch]) / self.stdev[ch]

    return target
241365
def denormalize_source_channels(self, source: torch.Tensor) -> torch.Tensor:
    """
    Denormalize source channels

    Channel i along the last axis is mapped to x * stdev + mean, using the statistics of
    dataset channel self.source_idx[i]. The input is modified in place and returned.

    Parameters
    ----------
    source :
        source data to be denormalized; last axis must have len(self.source_idx) entries

    Returns
    -------
    Denormalized data (the same object as the input)

    Raises
    ------
    AssertionError
        if the size of the last axis does not match the number of source channels
    """
    assert source.shape[-1] == len(self.source_idx), "incorrect number of channels"
    # i indexes the channel within source, ch the matching channel in the statistics
    for i, ch in enumerate(self.source_idx):
        source[..., i] = (source[..., i] * self.stdev[ch]) + self.mean[ch]

    return source
384+
def denormalize_target_channels(self, data: torch.Tensor) -> torch.Tensor:
    """
    Denormalize target channels

    Channel i along the last axis is mapped to x * stdev + mean, using the statistics of
    dataset channel self.target_idx[i]. The input is modified in place and returned.

    Parameters
    ----------
    data :
        data to be denormalized (target or pred); last axis must have
        len(self.target_idx) entries

    Returns
    -------
    Denormalized data (the same object as the input)

    Raises
    ------
    AssertionError
        if the size of the last axis does not match the number of target channels
    """
    assert data.shape[-1] == len(self.target_idx), "incorrect number of channels"
    # i indexes the channel within data, ch the matching channel in the statistics
    for i, ch in enumerate(self.target_idx):
        data[..., i] = (data[..., i] * self.stdev[ch]) + self.mean[ch]

    return data
403+
242404 def time_window (self , idx : int ) -> tuple [np .datetime64 , np .datetime64 ]:
405+ """
406+ Temporal window corresponding to index
407+
408+ Parameters
409+ ----------
410+ idx :
411+ index of temporal window
412+
413+ Returns
414+ -------
415+ start and end of temporal window
416+ """
243417 if not self .ds :
244418 return (np .array ([], dtype = np .datetime64 ), np .array ([], dtype = np .datetime64 ))
245419
0 commit comments