Skip to content

Commit c42c3e7

Browse files
authored
Misc fixes (ecmwf#171)
* Fixed bug in handling of forecast step for rollout. * Fixed problems where incomplete configs were saved. * Fixed problems in handling of forecast policy. * Ruffed
1 parent 59ca2ad commit c42c3e7

File tree

3 files changed

+15
-13
lines changed

3 files changed

+15
-13
lines changed

src/weathergen/datasets/multi_stream_data_sampler.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,16 @@ def __init__(self, cf, start_date, end_date, batch_size, samples_per_epoch, shuf
3939
self.len_hrs = cf.len_hrs
4040
self.step_hrs = cf.step_hrs
4141

42-
fc_policy_seq = (
43-
cf.forecast_policy == "sequential" or cf.forecast_policy == "sequential_random"
44-
)
45-
assert cf.forecast_steps >= 0 if not fc_policy_seq else True
4642
self.forecast_delta_hrs = (
4743
cf.forecast_delta_hrs if cf.forecast_delta_hrs > 0 else self.len_hrs
4844
)
4945
assert self.forecast_delta_hrs == self.len_hrs, "Only supported option at the moment"
5046
self.forecast_steps = np.array(
5147
[cf.forecast_steps] if type(cf.forecast_steps) == int else cf.forecast_steps
5248
)
49+
if cf.forecast_policy is not None:
50+
if self.forecast_steps.max() == 0:
51+
logger.warning("forecast policy is not None but number of forecast steps is 0.")
5352
self.forecast_policy = cf.forecast_policy
5453

5554
# end date needs to be adjusted to account for window length
@@ -194,7 +193,11 @@ def get_targets_coords_size(self):
194193

195194
###################################################
196195
def reset(self):
197-
fsm = self.forecast_steps[min(self.epoch, len(self.forecast_steps) - 1)]
196+
fsm = (
197+
self.forecast_steps[min(self.epoch, len(self.forecast_steps) - 1)]
198+
if self.forecast_policy != "random"
199+
else self.forecast_steps.max()
200+
)
198201
if fsm > 0:
199202
logger.info(f"forecast_steps at epoch={self.epoch} : {fsm}")
200203

@@ -319,7 +322,7 @@ def __iter__(self):
319322

320323
# collect for all forecast steps
321324
for fstep in range(forecast_dt + 1):
322-
# collect all sources
325+
# collect all targets
323326
for _, ds in enumerate(stream_ds):
324327
step_forecast_dt = (
325328
idx + (self.forecast_delta_hrs * fstep) // self.step_hrs

src/weathergen/model/model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -581,12 +581,12 @@ def forward(self, model_params, batch, forecast_steps):
581581

582582
# roll-out in latent space
583583
preds_all = []
584-
for _ in range(forecast_steps):
584+
for fstep in range(forecast_steps):
585585
# prediction
586586
preds_all += [
587587
self.predict(
588588
model_params,
589-
forecast_steps,
589+
fstep,
590590
tokens,
591591
streams_data,
592592
target_coords_idxs,
@@ -610,7 +610,6 @@ def forward(self, model_params, batch, forecast_steps):
610610

611611
#########################################
612612
def embed_cells(self, model_params, streams_data):
613-
# code.interact( local=locals())
614613
source_tokens_lens = torch.stack(
615614
[
616615
torch.stack(

src/weathergen/train/trainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,6 @@ def init(
8585
if self.cf.rank == 0:
8686
path_run.mkdir(exist_ok=True)
8787
path_model.mkdir(exist_ok=True)
88-
# save config
89-
cf.save()
90-
if run_mode == "training":
91-
cf.print()
9288
self.path_run = path_run
9389

9490
self.init_perf_monitoring()
@@ -136,6 +132,9 @@ def evaluate(self, cf, run_id_trained, epoch, run_id_new=False):
136132
for name, w in cf.loss_fcts_val:
137133
self.loss_fcts_val += [[getattr(losses, name), w]]
138134

135+
if self.cf.rank == 0:
136+
self.cf.save()
137+
139138
# evaluate validation set
140139
self.validate(epoch=0)
141140
print(f"Finished evaluation run with id: {cf.run_id}")
@@ -425,6 +424,7 @@ def run(self, cf, private_cf, run_id_contd=None, epoch_contd=None, run_id_new=Fa
425424
torch._dynamo.config.optimize_ddp = False
426425

427426
if self.cf.rank == 0:
427+
self.cf.save()
428428
self.cf.print()
429429

430430
# training loop

0 commit comments

Comments
 (0)