Skip to content

Commit fb6c2e5

Browse files
authored
Fix target times and handling of forecast steps (ecmwf#113)
* Working on having absolute time for targets directly available (currently only rel time in target window is) * Fixed docstrings * Ruffed. * - Fixed output to correctly handle multiple forecast steps - Changed handling of target time and coords for output so that these are passed now from dataloader (changes still need to be verified on reading side) * Fixing bug introduced in PR ecmwf#102 where target coords where no longer properly aligned with target. * Fixing bug that prevented training from running. Also fixing bug in target_lens for writing. * Enabling assert to make sure NaNs in target coords are handled correctly in the future. * Fixing formatting and linting. * Improved docstring and documentation of code. * Removed unused variable
1 parent 1ecf161 commit fb6c2e5

File tree

6 files changed

+356
-153
lines changed

6 files changed

+356
-153
lines changed

src/weathergen/datasets/anemoi_dataset.py

Lines changed: 205 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import logging
1212

1313
import numpy as np
14+
import torch
1415
from anemoi.datasets import open_dataset
1516

1617
_logger = logging.getLogger(__name__)
@@ -28,6 +29,29 @@ def __init__(
2829
filename: str,
2930
stream_info: dict,
3031
) -> None:
32+
"""
33+
Construct dataset based on anemoi dataset
34+
35+
Parameters
36+
----------
37+
start : int
38+
Start time
39+
end : int
40+
End time
41+
len_hrs : int
42+
length of data window
43+
step_hrs :
44+
delta hours between start times of windows
45+
filename :
46+
filename (and path) of dataset
47+
stream_info :
48+
information about stream
49+
50+
Returns
51+
-------
52+
None
53+
"""
54+
3155
# TODO: add support for different normalization modes
3256

3357
assert len_hrs == step_hrs, "Currently only step_hrs=len_hrs is supported"
@@ -106,31 +130,69 @@ def __init__(
106130
else:
107131
self.ds = open_dataset(ds, frequency=str(step_hrs) + "h", start=dt_start, end=dt_end)
108132

109-
def __len__(self):
110-
"Length of dataset"
133+
def __len__(self) -> int:
134+
"""
135+
Length of dataset
136+
137+
Parameters
138+
----------
139+
None
111140
141+
Returns
142+
-------
143+
length of dataset
144+
"""
112145
if not self.ds:
113146
return 0
114147

115148
return len(self.ds)
116149

117150
def get_source(self, idx: int) -> tuple[np.array, np.array, np.array, np.array]:
118151
"""
119-
TODO
152+
Get source data for idx
153+
154+
Parameters
155+
----------
156+
idx : int
157+
Index of temporal window
158+
159+
Returns
160+
-------
161+
source data (coords, geoinfos, data, datetimes)
120162
"""
121163
return self._get(idx, self.source_idx)
122164

123165
def get_target(self, idx: int) -> tuple[np.array, np.array, np.array, np.array]:
124166
"""
125-
TODO
167+
Get target data for idx
168+
169+
Parameters
170+
----------
171+
idx : int
172+
Index of temporal window
173+
174+
Returns
175+
-------
176+
target data (coords, geoinfos, data, datetimes)
126177
"""
127178
return self._get(idx, self.target_idx)
128179

129180
def _get(
130181
self, idx: int, channels_idx: np.array
131182
) -> tuple[np.array, np.array, np.array, np.array]:
132183
"""
133-
TODO
184+
Get data for window
185+
186+
Parameters
187+
----------
188+
idx : int
189+
Index of temporal window
190+
channels_idx : np.array
191+
Selection of channels
192+
193+
Returns
194+
-------
195+
data (coords, geoinfos, data, datetimes)
134196
"""
135197

136198
if not self.ds:
@@ -172,74 +234,186 @@ def _get(
172234

173235
return (latlon, geoinfos, data, datetimes)
174236

175-
def get_source_size(self):
176-
"""
177-
TODO
237+
def get_source_num_channels(self) -> int:
178238
"""
179-
return 2 + len(self.geoinfo_idx) + len(self.source_idx)
239+
Get number of source channels
180240
181-
def get_source_num_channels(self):
182-
"""
183-
TODO
241+
Parameters
242+
----------
243+
None
244+
245+
Returns
246+
-------
247+
number of source channels
184248
"""
185249
return len(self.source_idx)
186250

187-
def get_target_size(self):
251+
def get_target_num_channels(self) -> int:
188252
"""
189-
TODO
253+
Get number of target channels
254+
255+
Parameters
256+
----------
257+
None
258+
259+
Returns
260+
-------
261+
number of target channels
190262
"""
191-
return 2 + len(self.geoinfo_idx) + len(self.target_idx)
263+
return len(self.target_idx)
192264

193-
def get_target_num_channels(self):
265+
def get_coords_size(self) -> int:
194266
"""
195-
TODO
267+
Get size of coords
268+
269+
Parameters
270+
----------
271+
None
272+
273+
Returns
274+
-------
275+
size of coords
196276
"""
197-
return len(self.target_idx)
277+
return 2
198278

199-
def get_geoinfo_size(self):
279+
def get_geoinfo_size(self) -> int:
200280
"""
201-
TODO
281+
Get size of geoinfos
282+
283+
Parameters
284+
----------
285+
None
286+
287+
Returns
288+
-------
289+
size of geoinfos
202290
"""
203291
return len(self.geoinfo_idx)
204292

205-
def normalize_coords(self, coords):
293+
def normalize_coords(self, coords: torch.tensor) -> torch.tensor:
206294
"""
207-
TODO
295+
Normalize coordinates
296+
297+
Parameters
298+
----------
299+
coords :
300+
coordinates to be normalized
301+
302+
Returns
303+
-------
304+
Normalized coordinates
208305
"""
209306
coords[..., 0] = np.sin(np.deg2rad(coords[..., 0]))
210307
coords[..., 1] = np.sin(0.5 * np.deg2rad(coords[..., 1]))
211308

212309
return coords
213310

214-
def normalize_geoinfos(self, geoinfos):
311+
def normalize_geoinfos(self, geoinfos: torch.tensor) -> torch.tensor:
215312
"""
216-
TODO
313+
Normalize geoinfos
314+
315+
Parameters
316+
----------
317+
geoinfos :
318+
geoinfos to be normalized
319+
320+
Returns
321+
-------
322+
Normalized geoinfo
217323
"""
218324

219-
assert geoinfos.shape[-1] == 0
325+
assert geoinfos.shape[-1] == 0, "incorrect number of geoinfo channels"
220326
return geoinfos
221327

222-
def normalize_source_channels(self, source):
328+
def normalize_source_channels(self, source: torch.tensor) -> torch.tensor:
223329
"""
224-
TODO
330+
Normalize source channels
331+
332+
Parameters
333+
----------
334+
data :
335+
data to be normalized
336+
337+
Returns
338+
-------
339+
Normalized data
225340
"""
226-
assert source.shape[1] == len(self.source_idx)
341+
assert source.shape[-1] == len(self.source_idx), "incorrect number of channels"
227342
for i, ch in enumerate(self.source_idx):
228343
source[..., i] = (source[..., i] - self.mean[ch]) / self.stdev[ch]
229344

230345
return source
231346

232-
def normalize_target_channels(self, target):
347+
def normalize_target_channels(self, target: torch.tensor) -> torch.tensor:
233348
"""
234-
TODO
349+
Normalize target channels
350+
351+
Parameters
352+
----------
353+
data :
354+
data to be normalized
355+
356+
Returns
357+
-------
358+
Normalized data
235359
"""
236-
assert target.shape[1] == len(self.target_idx)
360+
assert target.shape[-1] == len(self.target_idx), "incorrect number of channels"
237361
for i, ch in enumerate(self.target_idx):
238362
target[..., i] = (target[..., i] - self.mean[ch]) / self.stdev[ch]
239363

240364
return target
241365

366+
def denormalize_source_channels(self, source: torch.tensor) -> torch.tensor:
367+
"""
368+
Denormalize source channels
369+
370+
Parameters
371+
----------
372+
data :
373+
data to be denormalized
374+
375+
Returns
376+
-------
377+
Denormalized data
378+
"""
379+
assert source.shape[-1] == len(self.source_idx), "incorrect number of channels"
380+
for i, ch in enumerate(self.source_idx):
381+
source[..., i] = (source[..., i] * self.stdev[ch]) + self.mean[ch]
382+
383+
return source
384+
385+
def denormalize_target_channels(self, data: torch.tensor) -> torch.tensor:
386+
"""
387+
Denormalize target channels
388+
389+
Parameters
390+
----------
391+
data :
392+
data to be denormalized (target or pred)
393+
394+
Returns
395+
-------
396+
Denormalized data
397+
"""
398+
assert data.shape[-1] == len(self.target_idx), "incorrect number of channels"
399+
for i, ch in enumerate(self.target_idx):
400+
data[..., i] = (data[..., i] * self.stdev[ch]) + self.mean[ch]
401+
402+
return data
403+
242404
def time_window(self, idx: int) -> tuple[np.datetime64, np.datetime64]:
405+
"""
406+
Temporal window corresponding to index
407+
408+
Parameters
409+
----------
410+
idx :
411+
index of temporal window
412+
413+
Returns
414+
-------
415+
start and end of temporal window
416+
"""
243417
if not self.ds:
244418
return (np.array([], dtype=np.datetime64), np.array([], dtype=np.datetime64))
245419

src/weathergen/datasets/batchifyer.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -406,19 +406,25 @@ def batchify_target(
406406
)
407407
hpy_idxs_ord_split = np.split(hpy_idxs_ord, splits + 1)
408408

409-
times = encode_times_target(times, time_win)
409+
times_enc = encode_times_target(times, time_win)
410410

411411
target_tokens = [torch.tensor([]) for _ in range(self.num_healpix_cells_target)]
412+
target_coords_raw = [torch.tensor([]) for _ in range(self.num_healpix_cells_target)]
412413
target_coords = [torch.tensor([]) for _ in range(self.num_healpix_cells_target)]
413414
target_geoinfos = [torch.tensor([]) for _ in range(self.num_healpix_cells_target)]
415+
target_times_raw = [torch.tensor([]) for _ in range(self.num_healpix_cells_target)]
414416
target_times = [torch.tensor([]) for _ in range(self.num_healpix_cells_target)]
415417
for i, c in enumerate(cells_idxs):
416418
t = normalizer.normalize_target_channels(source[hpy_idxs_ord_split[i]])
417-
t = t[self.rng.permutation(len(t))][: int(len(t) * sampling_rate_target)]
418-
target_tokens[c] = t
419-
target_coords[c] = coords[hpy_idxs_ord_split[i]]
420-
target_geoinfos[c] = normalizer.normalize_geoinfos(geoinfos[hpy_idxs_ord_split[i]])
421-
target_times[c] = times[hpy_idxs_ord_split[i]]
419+
perm = self.rng.permutation(len(t))[: int(len(t) * sampling_rate_target)]
420+
target_tokens[c] = t[perm]
421+
target_coords[c] = coords[hpy_idxs_ord_split[i]][perm]
422+
target_coords_raw[c] = coords[hpy_idxs_ord_split[i]][perm]
423+
target_geoinfos[c] = normalizer.normalize_geoinfos(
424+
geoinfos[hpy_idxs_ord_split[i]][perm]
425+
)
426+
target_times_raw[c] = times[hpy_idxs_ord_split[i]][perm]
427+
target_times[c] = times_enc[hpy_idxs_ord_split[i]][perm]
422428

423429
target_tokens_lens = torch.tensor([len(s) for s in target_tokens], dtype=torch.int32)
424430

@@ -436,4 +442,4 @@ def batchify_target(
436442
target_coords.requires_grad = False
437443
target_coords = list(target_coords.split(target_tokens_lens.tolist()))
438444

439-
return (target_tokens, target_coords)
445+
return (target_tokens, target_coords, target_coords_raw, target_times_raw)

0 commit comments

Comments
 (0)