@@ -492,7 +492,7 @@ def compute_loss(
492492 for fstep in range (forecast_steps + 1 )
493493 ]
494494
495- ctr = 0
495+ ctr_ftarget = 0
496496 loss = torch .tensor (0.0 , device = self .devices [0 ], requires_grad = True )
497497
498498 # assert len(targets_rt) == len(preds) and len(preds) == len(self.cf.streams)
@@ -520,10 +520,9 @@ def compute_loss(
520520 if target .shape [0 ] > 0 and pred .shape [0 ] > 0 :
521521 # extract data/coords and remove token dimension if it exists
522522 pred = pred .reshape ([pred .shape [0 ], * target .shape ])
523+ assert pred .shape [1 ] > 0
523524
524525 mask_nan = ~ torch .isnan (target )
525-
526- assert pred .shape [1 ] > 0
527526 if pred [:, mask_nan ].shape [1 ] == 0 :
528527 continue
529528 ens = pred .shape [0 ] > 1
@@ -532,16 +531,16 @@ def compute_loss(
532531 for j , (loss_fct , w ) in enumerate (loss_fcts ):
533532 # compute per channel loss
534533 # val_uw is unweighted loss for logging
535- val , val_uw , ctr = (
536- torch . tensor ( 0.0 , device = self . devices [ 0 ], requires_grad = True ),
537- 0.0 ,
538- 0.0 ,
539- )
534+ val = torch . tensor ( 0.0 , device = self . devices [ 0 ], requires_grad = True )
535+ val_uw = 0.0
536+ ctr_chs = 0.0
537+
538+ # loop over all channels
540539 for i in range (target .shape [- 1 ]):
540+ # if stream is internal time step, compute loss separately per step
541541 if tok_spacetime :
542542 # iterate over time steps and compute loss separately for each
543543 t_unique = torch .unique (target_coords [:, 1 ])
544- # tw = np.linspace( 1.0, 2.0, len(t_unique))
545544 for _jj , t in enumerate (t_unique ):
546545 mask_t = t == target_coords [:, 1 ]
547546 mask = torch .logical_and (mask_t , mask_nan [:, i ])
@@ -553,8 +552,8 @@ def compute_loss(
553552 (pred [:, mask , i ].std (0 ) if ens else torch .zeros (1 )),
554553 )
555554 val_uw += temp .item ()
556- val = val + channel_loss_weight [i ] * temp # * tw[jj]
557- ctr += 1
555+ val = val + channel_loss_weight [i ] * temp
556+ ctr_chs += 1
558557
559558 else :
560559 # only compute loss if there are non-NaN values
@@ -571,9 +570,9 @@ def compute_loss(
571570 )
572571 val_uw += temp .item ()
573572 val = val + channel_loss_weight [i ] * temp
574- ctr += 1
575- val = val / ctr if (ctr > 0 ) else val
576- val_uw = val_uw / ctr if (ctr > 0 ) else val_uw
573+ ctr_chs += 1
574+ val = val / ctr_chs if (ctr_chs > 0 ) else val
575+ val_uw = val_uw / ctr_chs if (ctr_chs > 0 ) else val_uw
577576
578577 losses_all [j , i_obs ] = val_uw
579578 if self .cf .loss_fcts [j ][0 ] == "stats" or self .cf .loss_fcts [j ][0 ] == "kcrps" :
@@ -584,20 +583,23 @@ def compute_loss(
584583 if not torch .isnan (val )
585584 else torch .tensor (0.0 , requires_grad = True )
586585 )
587- ctr += 1
586+ ctr_ftarget += 1
588587
589588 # log data for analysis
590589 if log_data :
591- # TODO: test
592590 targets_lens [fstep ][i_obs ] += [target .shape [0 ]]
593591 dn_data = self .dataset_val .denormalize_target_channels
594592
595593 f32 = torch .float32
596594 preds_all [fstep ][i_obs ] += [dn_data (i_obs , pred .to (f32 )).detach ().cpu ()]
597595 targets_all [fstep ][i_obs ] += [dn_data (i_obs , target .to (f32 )).detach ().cpu ()]
598596
597+ # normalize by all targets and forecast steps that were non-empty
598+ # (with each having an expected loss of 1 for an uninitialized neural net)
599+ loss /= ctr_ftarget
600+
599601 return (
600- loss / ctr ,
602+ loss ,
601603 None
602604 if not log_data
603605 else [
0 commit comments