Skip to content

Commit 7306a7a

Browse files
authored
Fix to bug in counting datasets contributing to loss (ecmwf#150)
* Fixed a bug in counting datasets contributing to the loss
* Added some comments and improved the clarity of the code structure a bit
1 parent a093765 commit 7306a7a

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

src/weathergen/train/trainer.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def compute_loss(
492492
for fstep in range(forecast_steps + 1)
493493
]
494494

495-
ctr = 0
495+
ctr_ftarget = 0
496496
loss = torch.tensor(0.0, device=self.devices[0], requires_grad=True)
497497

498498
# assert len(targets_rt) == len(preds) and len(preds) == len(self.cf.streams)
@@ -520,10 +520,9 @@ def compute_loss(
520520
if target.shape[0] > 0 and pred.shape[0] > 0:
521521
# extract data/coords and remove token dimension if it exists
522522
pred = pred.reshape([pred.shape[0], *target.shape])
523+
assert pred.shape[1] > 0
523524

524525
mask_nan = ~torch.isnan(target)
525-
526-
assert pred.shape[1] > 0
527526
if pred[:, mask_nan].shape[1] == 0:
528527
continue
529528
ens = pred.shape[0] > 1
@@ -532,16 +531,16 @@ def compute_loss(
532531
for j, (loss_fct, w) in enumerate(loss_fcts):
533532
# compute per channel loss
534533
# val_uw is unweighted loss for logging
535-
val, val_uw, ctr = (
536-
torch.tensor(0.0, device=self.devices[0], requires_grad=True),
537-
0.0,
538-
0.0,
539-
)
534+
val = torch.tensor(0.0, device=self.devices[0], requires_grad=True)
535+
val_uw = 0.0
536+
ctr_chs = 0.0
537+
538+
# loop over all channels
540539
for i in range(target.shape[-1]):
540+
# if stream is internal time step, compute loss separately per step
541541
if tok_spacetime:
542542
# iterate over time steps and compute loss separately for each
543543
t_unique = torch.unique(target_coords[:, 1])
544-
# tw = np.linspace( 1.0, 2.0, len(t_unique))
545544
for _jj, t in enumerate(t_unique):
546545
mask_t = t == target_coords[:, 1]
547546
mask = torch.logical_and(mask_t, mask_nan[:, i])
@@ -553,8 +552,8 @@ def compute_loss(
553552
(pred[:, mask, i].std(0) if ens else torch.zeros(1)),
554553
)
555554
val_uw += temp.item()
556-
val = val + channel_loss_weight[i] * temp # * tw[jj]
557-
ctr += 1
555+
val = val + channel_loss_weight[i] * temp
556+
ctr_chs += 1
558557

559558
else:
560559
# only compute loss if there are non-NaN values
@@ -571,9 +570,9 @@ def compute_loss(
571570
)
572571
val_uw += temp.item()
573572
val = val + channel_loss_weight[i] * temp
574-
ctr += 1
575-
val = val / ctr if (ctr > 0) else val
576-
val_uw = val_uw / ctr if (ctr > 0) else val_uw
573+
ctr_chs += 1
574+
val = val / ctr_chs if (ctr_chs > 0) else val
575+
val_uw = val_uw / ctr_chs if (ctr_chs > 0) else val_uw
577576

578577
losses_all[j, i_obs] = val_uw
579578
if self.cf.loss_fcts[j][0] == "stats" or self.cf.loss_fcts[j][0] == "kcrps":
@@ -584,20 +583,23 @@ def compute_loss(
584583
if not torch.isnan(val)
585584
else torch.tensor(0.0, requires_grad=True)
586585
)
587-
ctr += 1
586+
ctr_ftarget += 1
588587

589588
# log data for analysis
590589
if log_data:
591-
# TODO: test
592590
targets_lens[fstep][i_obs] += [target.shape[0]]
593591
dn_data = self.dataset_val.denormalize_target_channels
594592

595593
f32 = torch.float32
596594
preds_all[fstep][i_obs] += [dn_data(i_obs, pred.to(f32)).detach().cpu()]
597595
targets_all[fstep][i_obs] += [dn_data(i_obs, target.to(f32)).detach().cpu()]
598596

597+
# normalize by all targets and forecast steps that were non-empty
598+
# (with each having an expected loss of 1 for an uninitialized neural net)
599+
loss /= ctr_ftarget
600+
599601
return (
600-
loss / ctr,
602+
loss,
601603
None
602604
if not log_data
603605
else [

0 commit comments

Comments
 (0)