Commit 22d867a

Loss class refactoring (ecmwf#533)
* Fix bug with seed being divided by 0 for worker ID=0
* Fix bug causing crash when secrets aren't in private config
* Implement logging losses per channel
* Fix issue with empty targets
* Rework loss logging
* ruff
* Remove computing max_channels
* Change variables names
* ruffed
* Remove redundant enumerations
* Use stages for logging
* Add type hints
* Apply the review
* ruff
* fix
* Fix type hints
* ruff
* Implement sending tensors of different shapes
* ruff
* Fix merge
* Fix docstring
* rerun workflow
* creating loss class
* Adapted varnames in new compute_loss function to match LossModule
* comments and loss_fcts refactoring
* Suggested a separation of mask creation and loss computation
* first working version of LossModule; added unit test
* Modifications and TODOs after meeting with Christian and Julian
* Added Christian's comments and updated code partially
* Julian & Matze further advances to understand shapes
* New mask_t computations. Not yet correct, thus commented
* Resolved reshaping of tensors for loss computation
* small changes in _prepare_logging
* J&M first refactoring version finished, 2 tests ok
* First round of resolving PR comments
* add ModelLoss dataclass, rearrange mask and loss computation
* Integrating new LossCalculator into trainer.py and adding docstrings
* J&M resolved temp.item() error
* Second round of PR comments integrated
* Fixed loss accumulation; cleaned up variable names
* Renamed weight
* Removed unused vars
* Inspected loss normalization for logging
* Minor clean-up
* Removing unused code.
* More refactoring: breaking code down in smaller pieces
* Fix
* Adding missing copyright
* Adding missing copyright
* Fixing incorrect indent
* Fix

---------

Co-authored-by: Kacper Nowak <kacper.nowak@awi.de>
Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int>
Co-authored-by: Julian Kuehnert <julian.kuehnert@ecwmf.int>
Co-authored-by: Matthias Karlbauer <matthias.karlbauer@ecmwf.int>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Co-authored-by: clessig <christian.lessig@ecwmf.int>
1 parent cc1f076 · commit 22d867a

5 files changed (+493, -224 lines)


config/default_config.yml

Lines changed: 2 additions & 1 deletion

@@ -1,4 +1,5 @@
-streams_directory: "./config/streams/streams_anemoi/"
+# streams_directory: "./config/streams/streams_anemoi/"
+streams_directory: "./config/streams/streams_mixed/"
 
 embed_orientation: "channels"
 embed_local_coords: True

packages/common/src/weathergen/common/io.py

Lines changed: 9 additions & 0 deletions

@@ -1,3 +1,12 @@
+# (C) Copyright 2025 WeatherGenerator contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
 import dataclasses
 import functools
 import itertools
Lines changed: 353 additions & 0 deletions

@@ -0,0 +1,353 @@
+# (C) Copyright 2025 WeatherGenerator contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import dataclasses
+import logging
+
+import numpy as np
+import torch
+from omegaconf import DictConfig
+from torch import Tensor
+
+import weathergen.train.loss as losses
+from weathergen.train.loss import stat_loss_fcts
+from weathergen.utils.train_logger import TRAIN, VAL, Stage
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class LossValues:
+    """
+    A dataclass to encapsulate the various loss components computed by the LossCalculator.
+
+    This provides a structured way to return the primary loss used for optimization,
+    along with detailed per-stream/per-channel/per-loss-function losses for logging,
+    and standard deviations for ensemble scenarios.
+    """
+
+    # The primary scalar loss value for optimization.
+    loss: Tensor
+    # Dictionaries containing detailed loss values for each stream, channel, and loss function, as
+    # well as standard deviations when operating with ensembles (e.g., when training with CRPS).
+    losses_all: dict[str, Tensor]
+    stddev_all: dict[str, Tensor]
+
+
+class LossCalculator:
+    """
+    Manages and computes the overall loss for a WeatherGenerator model during
+    training and validation stages.
+
+    This class handles the initialization and application of various loss functions,
+    applies channel-specific weights, constructs masks for missing data, and
+    aggregates losses across different data streams, channels, and forecast steps.
+    It provides both the main loss for backpropagation and detailed loss metrics for logging.
+    """
+
+    def __init__(
+        self,
+        cf: DictConfig,
+        stage: Stage,
+        device: str,
+    ):
+        """
+        Initializes the LossCalculator.
+
+        This sets up the configuration, the operational stage (training or validation),
+        the device for tensor operations, and initializes the list of loss functions
+        based on the provided configuration.
+
+        Args:
+            cf: The OmegaConf DictConfig object containing model and training configurations.
+                It should specify 'loss_fcts' for training and 'loss_fcts_val' for validation.
+            stage: The current operational stage, either TRAIN or VAL.
+                This dictates which set of loss functions (training or validation) will be used.
+            device: The computation device, such as 'cpu' or 'cuda:0', where tensors will reside.
+        """
+        self.cf = cf
+        self.stage = stage
+        self.device = device
+
+        # Dynamically load loss functions based on configuration and stage
+        loss_fcts = cf.loss_fcts if stage == TRAIN else cf.loss_fcts_val
+        self.loss_fcts = [[getattr(losses, name), w] for name, w in loss_fcts]
+
+    @staticmethod
+    def _construct_masks(
+        target_times_raw: np.ndarray, mask_nan: Tensor, tok_spacetime: bool
+    ) -> list[Tensor]:
+        """
+        Constructs a list of boolean masks for target data.
+
+        If 'tok_spacetime' is enabled, masks are generated for unique intermediate time steps
+        within a single forecast step and combined with a NaN mask. Otherwise, a single mask
+        for non-NaN values is returned. This is useful for datasets where targets might have
+        sub-timestep granularity.
+
+        Args:
+            target_times_raw: A NumPy array containing raw time values for targets
+                within a single forecast step.
+            mask_nan: A PyTorch Tensor indicating non-NaN values for the specific channel.
+            tok_spacetime: A boolean flag indicating whether spacetime tokenization is active,
+                which influences mask construction.
+
+        Returns:
+            A list of PyTorch boolean Tensors, where each tensor is a combined mask for
+            a unique time point or simply the non-NaN mask.
+        """
+        masks = []
+        if tok_spacetime:
+            t_unique = np.unique(target_times_raw)
+            for t in t_unique:
+                mask_t = Tensor(t == target_times_raw).to(mask_nan)
+                masks.append(torch.logical_and(mask_t, mask_nan))
+        else:
+            masks.append(mask_nan)
+        return masks
+
+    @staticmethod
+    def _compute_loss_with_mask(
+        target: Tensor, pred: Tensor, mask: Tensor, i_ch: int, loss_fct: losses, ens: bool
+    ) -> Tensor:
+        """
+        Computes the loss for a specific channel using a given mask.
+
+        This helper function applies a chosen loss function to the masked target and prediction
+        data for a single channel, handling ensemble predictions by calculating mean and standard
+        deviation over the ensemble dimension.
+
+        Args:
+            target: The ground truth target tensor.
+            pred: The prediction tensor, potentially with an ensemble dimension.
+            mask: A boolean mask tensor indicating which elements to consider for loss computation.
+            i_ch: The index of the channel for which to compute the loss.
+            loss_fct: The specific loss function to apply. It is expected to accept
+                (masked_target, masked_pred, pred_mean, pred_std).
+            ens: A boolean flag indicating whether 'pred' contains an ensemble dimension.
+
+        Returns:
+            The computed loss value for the masked data, or a tensor with value 0 if no
+            valid data points are present under the mask.
+        """
+        if mask.sum().item() > 0:
+            # Only compute loss if there are non-NaN values
+            return loss_fct(
+                target[mask, i_ch],
+                pred[:, mask, i_ch],
+                pred[:, mask, i_ch].mean(0),
+                (pred[:, mask, i_ch].std(0) if ens else torch.zeros(1, device=pred.device)),
+            )
+        else:
+            # If no valid data under the mask, return 0 to avoid errors and not contribute to loss
+            return torch.tensor(0.0, device=pred.device)
+
+    def _compute_loss_per_loss_function(
+        self,
+        loss_fct,
+        i_lfct,
+        i_batch,
+        i_strm,
+        strm,
+        fstep,
+        streams_data,
+        target,
+        pred,
+        mask_nan,
+        channel_loss_weight,
+        losses_all,
+    ):
+        tok_spacetime = strm["tokenize_spacetime"] if "tokenize_spacetime" in strm else False
+        ens = pred.shape[0] > 1
+
+        # compute per channel loss
+        loss_lfct = torch.tensor(0.0, device=self.device, requires_grad=True)
+        ctr_chs = 0
+
+        # loop over all channels within the current stream and forecast step
+        for i_ch in range(target.shape[-1]):
+            # construct masks based on spacetime tokenization setting
+            masks = self._construct_masks(
+                target_times_raw=streams_data[i_batch][i_strm].target_times_raw[
+                    self.cf.forecast_offset + fstep
+                ],
+                mask_nan=mask_nan[:, i_ch],
+                tok_spacetime=tok_spacetime,
+            )
+            ctr_substeps = 0
+            for mask in masks:
+                loss_ch = self._compute_loss_with_mask(
+                    target=target,
+                    pred=pred,
+                    mask=mask,
+                    i_ch=i_ch,
+                    loss_fct=loss_fct,
+                    ens=ens,
+                )
+                # accumulate weighted loss for this loss function and channel
+                loss_lfct = loss_lfct + (channel_loss_weight[i_ch] * loss_ch)
+                ctr_chs += 1 if loss_ch > 0.0 else 0
+                ctr_substeps += 1 if loss_ch > 0.0 else 0
+                # for detailed logging
+                losses_all[strm.name][i_ch, i_lfct] += loss_ch.item()
+
+            # normalize over sub-timesteps within the forecast step
+            losses_all[strm.name][i_ch, i_lfct] /= ctr_substeps if ctr_substeps > 0 else 1.0
+
+        # normalize the accumulated loss for the current loss function
+        loss_lfct = loss_lfct / ctr_chs if (ctr_chs > 0) else loss_lfct
+
+        return loss_lfct, losses_all
+
+    def compute_loss(
+        self,
+        preds: list[list[Tensor]],
+        streams_data: list[
+            list[any]
+        ],  # Assuming Stream is a dataclass/object for each stream in a batch
+    ) -> LossValues:
+        """
+        Computes the total loss for a given batch of predictions and corresponding
+        stream data.
+
+        This method orchestrates the calculation of the overall loss by iterating through
+        different data streams, forecast steps, channels, and configured loss functions.
+        It applies weighting, handles NaN values through masking, and accumulates
+        detailed loss metrics for logging.
+
+        Args:
+            preds: A nested list of prediction tensors. The outer list represents forecast steps,
+                the inner list represents streams. Each tensor contains predictions for that
+                step and stream.
+            streams_data: A nested list representing the input batch data. The outer list is for
+                batch items, the inner list for streams. Each element provides an object
+                (e.g., dataclass instance) containing target data and metadata.
+
+        Returns:
+            A LossValues dataclass instance containing:
+            - loss: The loss for back-propagation.
+            - losses_all: A dictionary mapping stream names to a tensor of per-channel and
+                per-loss-function losses, normalized by non-empty targets/forecast steps.
+            - stddev_all: A dictionary mapping stream names to a tensor of mean standard deviations
+                of predictions for channels with statistical loss functions, normalized.
+        """

+        # gradient loss
+        loss = torch.tensor(0.0, device=self.device, requires_grad=True)
+        # counter for non-empty targets
+        ctr_streams = 0
+
+        # initialize dictionaries for detailed loss tracking and standard deviation statistics
+        # create tensor for each stream
+        losses_all: dict[str, Tensor] = {
+            st.name: torch.zeros(
+                (len(st[str(self.stage) + "_target_channels"]), len(self.loss_fcts)),
+                device=self.device,
+            )
+            for st in self.cf.streams
+        }
+        stddev_all: dict[str, Tensor] = {
+            st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams
+        }
+
+        # TODO: iterate over batch dimension
+        i_batch = 0
+        for i_strm, strm in enumerate(self.cf.streams):
+            # extract target tokens for current stream from the specified forecast offset onwards
+            targets = streams_data[i_batch][i_strm].target_tokens[self.cf.forecast_offset :]
+
+            loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True)
+            ctr_fsteps = 0
+
+            for fstep, target in enumerate(targets):
+                # skip if either target or prediction has no data points
+                pred = preds[fstep][i_strm]
+                if not (target.shape[0] > 0 and pred.shape[0] > 0):
+                    continue
+
+                num_channels = len(strm[str(self.stage) + "_target_channels"])
+
+                # Determine stream and channel loss weights based on the current stage
+                if self.stage == TRAIN:
+                    # set loss_weights to 1. when not specified
+                    strm_loss_weight = strm["loss_weight"] if "loss_weight" in strm else 1.0
+                    channel_loss_weight = (
+                        strm["channel_weight"]
+                        if "channel_weight" in strm
+                        else np.ones(num_channels)
+                    )
+                elif self.stage == VAL:
+                    # in validation mode, always unweighted loss
+                    strm_loss_weight = 1.0
+                    channel_loss_weight = np.ones(num_channels)
+
+                # reshape prediction tensor to match target's dimensions: extract data/coords and
+                # remove token dimension if it exists.
+                # expected final shape of pred is [ensemble_size, num_samples, num_channels].
+                pred = pred.reshape([pred.shape[0], *target.shape])
+                assert pred.shape[1] > 0
+
+                mask_nan = ~torch.isnan(target)
+                # if all valid predictions are masked out by NaNs, skip this forecast step
+                if pred[:, mask_nan].shape[1] == 0:
+                    continue
+
+                # accumulate loss from different loss functions and across channels
+                for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts):
+                    loss_lfct, losses_all = self._compute_loss_per_loss_function(
+                        loss_fct,
+                        i_lfct,
+                        i_batch,
+                        i_strm,
+                        strm,
+                        fstep,
+                        streams_data,
+                        target,
+                        pred,
+                        mask_nan,
+                        channel_loss_weight,
+                        losses_all,
+                    )
+
+                    # Update statistical deviation metrics if the current loss function is
+                    # recognized as statistical
+                    if loss_fct.__name__ in stat_loss_fcts:
+                        indx = stat_loss_fcts.index(loss_fct.__name__)
+                        stddev_all[strm.name][indx] += pred[:, mask_nan].std(0).mean().item()
+
+                    # Add the weighted and normalized loss from this loss function to the total
+                    # batch loss
+                    loss_fstep = loss_fstep + (loss_fct_weight * loss_lfct * strm_loss_weight)
+                ctr_fsteps += 1 if loss_lfct > 0.0 else 0
+
+            loss = loss + loss_fstep / ctr_fsteps if ctr_fsteps > 0 else loss
+            ctr_streams += 1 if loss_fstep > 0 else 0
+
+            # normalize by forecast step
+            losses_all[strm.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0
+            stddev_all[strm.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0
+
+            # replace channels without information by nan to exclude from further computations
+            losses_all[strm.name][losses_all[strm.name] == 0.0] = torch.nan
+            stddev_all[strm.name][stddev_all[strm.name] == 0.0] = torch.nan
+
+        if loss == 0.0:
+            # streams_data[i] are samples in batch
+            # streams_data[i][0] is stream 0 (sample_idx is identical for all streams per sample)
+            _logger.warning(
+                f"Loss is 0.0 for sample(s): {[sd[0].sample_idx.item() for sd in streams_data]}. "
+                + "This will likely lead to errors in the optimization step."
+            )
+
+        # normalize by all targets and forecast steps that were non-empty
+        # (with each having an expected loss of 1 for an uninitialized neural net)
+        loss = loss / ctr_streams
+
+        # Return all computed loss components encapsulated in a LossValues dataclass
+        return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all)
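To make the new masking and reduction logic concrete, here is a small self-contained Python example of the per-time-point masks that _construct_masks builds under spacetime tokenization, followed by the masked ensemble reduction pattern used in _compute_loss_with_mask. All numbers, shapes, and the plain-MSE stand-in loss are invented for illustration; they are not part of this commit.

import numpy as np
import torch

# Toy data: one channel, six target points at two raw times, an ensemble of
# three predictions, and one missing (NaN) target value.
target_times_raw = np.array([0.0, 0.0, 0.0, 6.0, 6.0, 6.0])
target = torch.tensor([[1.0], [float("nan")], [2.0], [3.0], [4.0], [5.0]])
pred = torch.randn(3, 6, 1)  # [ensemble, points, channels]
i_ch = 0

mask_nan = ~torch.isnan(target)[:, i_ch]

# One combined mask per unique time point, as in _construct_masks:
masks = []
for t in np.unique(target_times_raw):
    mask_t = torch.Tensor(t == target_times_raw).to(mask_nan)
    masks.append(torch.logical_and(mask_t, mask_nan))
# masks[0] -> [True, False, True, False, False, False]
# masks[1] -> [False, False, False, True, True, True]

# Masked ensemble reduction, as in _compute_loss_with_mask:
for mask in masks:
    if mask.sum().item() > 0:
        pred_mean = pred[:, mask, i_ch].mean(0)  # ensemble mean per masked point
        pred_std = pred[:, mask, i_ch].std(0)  # ensemble spread per masked point
        loss_ch = ((target[mask, i_ch] - pred_mean) ** 2).mean()  # stand-in loss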

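Finally, a sketch of how the new class might be wired into a training step. This is illustrative only: the module path of LossCalculator is a guess (the header naming the new file was lost in this excerpt), and cf is assumed to carry the fields used above, namely loss_fcts as pairs of loss-function name and weight resolved against weathergen.train.loss, plus streams and forecast_offset.

from weathergen.train.loss_calculator import LossCalculator  # module path is a guess
from weathergen.utils.train_logger import TRAIN


def training_step(cf, preds, streams_data, device: str = "cpu"):
    """Illustrative sketch: compute the batch loss and return logging metrics."""
    loss_calculator = LossCalculator(cf=cf, stage=TRAIN, device=device)

    # preds: [forecast_step][stream] -> Tensor
    # streams_data: [batch_item][stream] -> object carrying target_tokens, etc.
    loss_values = loss_calculator.compute_loss(preds, streams_data)

    # The scalar loss drives the optimizer step; the dictionaries feed
    # per-stream, per-channel logging.
    loss_values.loss.backward()
    return loss_values.losses_all, loss_values.stddev_all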