
Commit 6322968

Jubeku (Julian Kuehnert) and co-author authored
Continue training through slurm script (ecmwf#395)

* train_continue via slurm
* using __main__ as entry point for slurm script
* reverting config files to match base branch
* removing param_sum control logging before and after loading of model weights
* run ruff
* check whether from_run_id is in arguments
* trigger PR check
* remove block to set reuse_run_id=True

---------

Co-authored-by: Julian Kuehnert <julian.kuehnert@ecwmf.int>
1 parent 54244de commit 6322968

File tree

3 files changed (+19 additions, -6 deletions)

config/default_config.yml

Lines changed: 0 additions & 1 deletion

@@ -131,4 +131,3 @@ run_id: ???
 train_log:
   # The period to log metrics (in number of batch steps)
   log_interval: 20
-

src/weathergen/run_train.py

Lines changed: 6 additions & 2 deletions

@@ -173,5 +173,9 @@ def train_with_args(argl: list[str], stream_dir: str | None):


 if __name__ == "__main__":
-    train()
-    # train_continue()
+    # Entry point for slurm script.
+    # Check whether --from_run_id passed as argument.
+    if "--from_run_id" in sys.argv:
+        train_continue()
+    else:
+        train()
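
For readers skimming the diff: the new __main__ block is a plain membership test on sys.argv, so any invocation that mentions --from_run_id anywhere on the command line is routed to train_continue(), and everything else falls through to train(). A minimal self-contained sketch of that dispatch pattern (the function bodies and invocation paths below are illustrative placeholders, not the repository's implementations):

    # Sketch of the argv-based dispatch added in run_train.py.
    # The train/train_continue bodies here are placeholders.
    import sys

    def train():
        print("starting a fresh training run")

    def train_continue():
        print("continuing from an existing run id")

    # Hypothetical invocations (paths and run ids are illustrative):
    #   python src/weathergen/run_train.py                       -> train()
    #   python src/weathergen/run_train.py --from_run_id abc123  -> train_continue()
    if __name__ == "__main__":
        if "--from_run_id" in sys.argv:
            train_continue()
        else:
            train()

One consequence of the membership test, as opposed to a full argument-parser definition at the entry point: --from_run_id with a missing or malformed value still selects the continue path, and validation of the run id is left to the downstream argument handling.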

src/weathergen/train/trainer.py

Lines changed: 13 additions & 3 deletions

@@ -16,8 +16,13 @@
 import tqdm
 from torch.distributed.fsdp import FullStateDictConfig, StateDictType
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision, ShardingStrategy
-from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy  # default_auto_wrap_policy,
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    MixedPrecision,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp.wrap import (
+    size_based_auto_wrap_policy,  # default_auto_wrap_policy,
+)

 import weathergen.train.loss as losses
 import weathergen.utils.config as config

@@ -377,7 +382,12 @@ def compute_loss(
         # assert len(targets_rt) == len(preds) and len(preds) == len(self.cf.streams)
         for fstep in range(len(targets_rt)):
             for i_obs, (target, target_coords, si) in enumerate(
-                zip(targets_rt[fstep], targets_coords_rt[fstep], self.cf.streams, strict=False)
+                zip(
+                    targets_rt[fstep],
+                    targets_coords_rt[fstep],
+                    self.cf.streams,
+                    strict=False,
+                )
             ):
                 pred = preds[fstep][i_obs]
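
Both trainer.py hunks are line-length reflows (the commit message notes a ruff pass); behaviour is unchanged. One detail the reflow makes more visible is the strict=False in the zip call: since Python 3.10, zip accepts a strict keyword, and strict=False (the default) silently truncates to the shortest input, while the commented-out assert above the loop suggests the lengths are expected to match anyway. A standalone illustration with made-up values:

    # zip's strict flag (Python 3.10+); the data here is invented,
    # not the repository's.
    targets = ["t0", "t1"]
    streams = ["obs_a", "obs_b", "obs_c"]

    # strict=False (the default): silently truncates to the shortest input.
    print(list(zip(targets, streams, strict=False)))
    # -> [('t0', 'obs_a'), ('t1', 'obs_b')]

    # strict=True: raises ValueError on unequal lengths.
    try:
        list(zip(targets, streams, strict=True))
    except ValueError as err:
        print(err)  # e.g. "zip() argument 2 is longer than argument 1"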

Comments (0)