
Commit c4afa5c

[219] Ensures the run_id passed through the command line is propagated (ecmwf#252)
1 parent 3f97922 commit c4afa5c

File tree: 7 files changed, +117 −22 lines


integration_tests/small1.py

Lines changed: 3 additions & 1 deletion
@@ -55,10 +55,12 @@ def test_train(setup, test_run_id):
     )
 
     evaluate_from_args(
-        "-start 2022-10-10 -end 2022-10-11 --samples 10 --same_run_id --epoch 0".split()
+        "-start 2022-10-10 -end 2022-10-11 --samples 10 --epoch 0".split()
         + [
             "--run_id",
             test_run_id,
+            "--eval_run_id",
+            test_run_id,
             "--config",
             f"{weathergen_home}/integration_tests/small1.yaml",
         ]
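
The test exercises the replacement for the removed --same_run_id switch: passing the training id again through --eval_run_id keeps the evaluation results under the same run id. A sketch of the two invocation styles (the run id "abc123" is illustrative; the flags and config path are taken from the diff above):

# Sketch, not part of the commit: the two ways to direct evaluation output.
from weathergen import evaluate_from_args

base = "-start 2022-10-10 -end 2022-10-11 --samples 10 --epoch 0".split()
cfg = ["--config", "integration_tests/small1.yaml"]

# Store evaluation results under the training run id (old --same_run_id behavior):
evaluate_from_args(base + ["--run_id", "abc123", "--eval_run_id", "abc123"] + cfg)

# Omit --eval_run_id and the Trainer falls back to its own run-id logic:
evaluate_from_args(base + ["--run_id", "abc123"] + cfg)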

src/weathergen/__init__.py

Lines changed: 6 additions & 7 deletions
@@ -91,12 +91,11 @@ def evaluate_from_args(argl: list[str]):
         help="Path to private configuration file for paths.",
     )
     parser.add_argument(
-        "-n",
-        "--same_run_id",
+        "--eval_run_id",
+        type=str,
         required=False,
-        dest="run_id_new",
-        action="store_false",
-        help="store evaluation results in the same folder as run_id",
+        dest="eval_run_id",
+        help="(optional) if specified, uses the provided run id to store the evaluation results",
     )
     parser.add_argument(
         "--config",
@@ -133,7 +132,7 @@ def evaluate_from_args(argl: list[str]):
     cf.loader_num_workers = min(cf.loader_num_workers, args.samples)
 
     trainer = Trainer()
-    trainer.evaluate(cf, args.run_id, args.epoch, args.run_id_new)
+    trainer.evaluate(cf, args.run_id, args.epoch, run_id_new=args.eval_run_id)
 
 
 ####################################################################################################
@@ -273,7 +272,7 @@ def train_with_args(argl: list[str], stream_dir: str | None):
     trainer = Trainer(checkpoint_freq=250, print_freq=10)
 
     try:
-        trainer.run(cf)
+        trainer.run(cf, run_id_new=args.run_id)
     except Exception:
         extype, value, tb = sys.exc_info()
         traceback.print_exc()
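
Since --eval_run_id has no default, args.eval_run_id is None when the flag is omitted, so trainer.evaluate() receives run_id_new=None and falls through to the Trainer's own run-id logic. A stdlib-only sketch of that behavior (no weathergen code needed):

# Sketch: semantics of the new flag, using only argparse.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--eval_run_id",
    type=str,
    required=False,
    dest="eval_run_id",
    help="(optional) if specified, uses the provided run id to store the evaluation results",
)

assert parser.parse_args([]).eval_run_id is None                    # omitted -> None
assert parser.parse_args(["--eval_run_id", "x7"]).eval_run_id == "x7"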

src/weathergen/train/trainer.py

Lines changed: 8 additions & 6 deletions
@@ -28,6 +28,7 @@
 from weathergen.train.trainer_base import Trainer_Base
 from weathergen.train.utils import get_run_id
 from weathergen.utils.config import Config
+from weathergen.utils.distributed import is_root
 from weathergen.utils.train_logger import TrainLogger
 from weathergen.utils.validation_io import write_validation
 
@@ -48,12 +49,12 @@ def init(
         cf: Config,
         run_id_contd=None,
         epoch_contd=None,  # unused
-        run_id_new=False,
+        run_id_new: bool | str | None = False,
         run_mode="training",  # unused
     ):
         self.cf = cf
 
-        if isinstance(run_id_new, str):
+        if run_id_new is not None and isinstance(run_id_new, str):
             cf.run_id = run_id_new
         elif run_id_new or cf.run_id is None:
             cf.run_id = get_run_id()
@@ -64,6 +65,7 @@ def init(
         assert cf.samples_per_epoch % cf.batch_size == 0
         assert cf.samples_per_validation % cf.batch_size_validation == 0
 
+        _logger.info(f"Starting run with id: {cf.run_id}")
         self.devices = self.init_torch()
 
         self.init_ddp(cf)
@@ -82,7 +84,6 @@ def init(
         self.path_run = path_run
 
         self.init_perf_monitoring()
-
         self.train_logger = TrainLogger(cf, self.path_run)
 
     ###########################################
@@ -134,7 +135,7 @@ def evaluate(self, cf, run_id_trained, epoch, run_id_new=False):
         _logger.info(f"Finished evaluation run with id: {cf.run_id}")
 
     ###########################################
-    def run(self, cf, run_id_contd=None, epoch_contd=None, run_id_new=False):
+    def run(self, cf, run_id_contd=None, epoch_contd=None, run_id_new: bool | str = False):
         # general initalization
         self.init(cf, run_id_contd, epoch_contd, run_id_new)
 
@@ -169,6 +170,7 @@ def run(self, cf, run_id_contd=None, epoch_contd=None, run_id_new=False):
         self.model = Model(cf, sources_size, targets_num_channels, targets_coords_size).create()
         # load model if specified
         if run_id_contd is not None:
+            _logger.info(f"Continuing run with id={run_id_contd} at epoch {epoch_contd}.")
             self.model.load(run_id_contd, epoch_contd)
             _logger.info(f"Loaded model id={run_id_contd}.")
 
@@ -278,7 +280,7 @@ def run(self, cf, run_id_contd=None, epoch_contd=None, run_id_new=False):
         if cf.forecast_policy is not None:
             torch._dynamo.config.optimize_ddp = False
 
-        if self.cf.rank == 0:
+        if is_root():
             config.save(self.cf, None)
             config.print_cf(self.cf)
 
@@ -674,7 +676,7 @@ def save_model(self, epoch: int, name=None):
         else:
             state = self.ddp_model.state_dict()
 
-        if self.cf.rank == 0:
+        if is_root():
             filename = "".join(
                 [
                     self.cf.run_id,
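
The run-id resolution that Trainer.init() now implements can be restated standalone; a sketch (get_run_id below is a stand-in for weathergen.train.utils.get_run_id, and the id values are illustrative):

# Sketch, not the repo's code: precedence of run_id_new over the config's run_id.
def resolve_run_id(run_id_new, cf_run_id, get_run_id=lambda: "fresh-id"):
    if run_id_new is not None and isinstance(run_id_new, str):
        return run_id_new      # explicit id (e.g. from --eval_run_id or --run_id) wins
    if run_id_new or cf_run_id is None:
        return get_run_id()    # truthy flag or missing id -> generate a new one
    return cf_run_id           # otherwise keep the id already in the config

assert resolve_run_id("abc123", None) == "abc123"
assert resolve_run_id(None, None) == "fresh-id"
assert resolve_run_id(False, "xyz789") == "xyz789"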

src/weathergen/train/trainer_base.py

Lines changed: 42 additions & 5 deletions
@@ -8,16 +8,18 @@
 # nor does it submit to any jurisdiction.
 
 import datetime
+import errno
 import logging
 import os
+import socket
 
 import pynvml
 import torch
 import torch.distributed as dist
 import torch.multiprocessing
-import torch.utils.data.distributed
 
 from weathergen.train.utils import str_to_tensor, tensor_to_str
+from weathergen.utils.distributed import is_root
 
 _logger = logging.getLogger(__name__)
 
@@ -70,27 +72,61 @@ def init_ddp(cf):
         cf.with_ddp = False
         cf.rank = rank
         cf.num_ranks = num_ranks
+        _logger.info(
+            "DDP not initialized. MASTER_ADDR not set. Running in single process mode."
+        )
         return
 
     local_rank = int(os.environ.get("SLURM_LOCALID"))
     ranks_per_node = int(os.environ.get("SLURM_TASKS_PER_NODE", "1")[0])
     rank = int(os.environ.get("SLURM_NODEID")) * ranks_per_node + local_rank
     num_ranks = int(os.environ.get("SLURM_NTASKS"))
+    _logger.info(
+        f"DDP initialization: local_rank={local_rank}, ranks_per_node={ranks_per_node}, rank={rank}, num_ranks={num_ranks}"
+    )
+
+    if rank == 0:
+        # Check that port 1345 is available, raise an error if not
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            try:
+                s.bind((master_node, 1345))
+            except OSError as e:
+                if e.errno == errno.EADDRINUSE:
+                    _logger.error(
+                        f"Port 1345 is already in use on {master_node}. Please check your network configuration."
+                    )
+                    raise
+                else:
+                    _logger.error(f"Error while binding to port 1345 on {master_node}: {e}")
+                    raise
+
+    _logger.info(
+        f"Initializing DDP with rank {rank} out of {num_ranks} on master_node:{master_node}."
+    )
 
     dist.init_process_group(
         backend="nccl",
         init_method="tcp://" + master_node + ":1345",
-        timeout=datetime.timedelta(seconds=10 * 8192),
+        timeout=datetime.timedelta(seconds=240),
         world_size=num_ranks,
         rank=rank,
+        device_id=torch.device("cuda", local_rank),
     )
+    if is_root():
+        _logger.info("DDP initialized: root.")
+    # Wait for all ranks to reach this point
+    dist.barrier()
 
     # communicate run id to all nodes
-    run_id_int = torch.zeros(8, dtype=torch.int32).cuda()
-    if rank == 0:
+    len_run_id = len(cf.run_id)
+    run_id_int = torch.zeros(len_run_id, dtype=torch.int32).cuda()
+    if is_root():
+        _logger.info(f"Communicating run_id to all nodes: {cf.run_id}")
         run_id_int = str_to_tensor(cf.run_id).cuda()
     dist.all_reduce(run_id_int, op=torch.distributed.ReduceOp.SUM)
-    cf.run_id = tensor_to_str(run_id_int)
+    if not is_root():
+        cf.run_id = tensor_to_str(run_id_int)
+    _logger.info(f"rank: {rank} has run_id: {cf.run_id}")
 
     # communicate data_loader_rng_seed
     if hasattr(cf, "data_loader_rng_seed"):
@@ -101,6 +137,7 @@ def init_ddp(cf):
         dist.all_reduce(l_seed, op=torch.distributed.ReduceOp.SUM)
         cf.data_loader_rng_seed = l_seed.item()
 
+    # TODO: move outside of the config
     cf.rank = rank
     cf.num_ranks = num_ranks
    cf.with_ddp = True
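
For readers unfamiliar with the run-id exchange above: a string cannot be sent through NCCL directly, so it is encoded as an int32 tensor, summed across ranks (non-root ranks contribute zeros), and decoded again. A single-process sketch of the encoding, with str_to_tensor/tensor_to_str reimplemented here as stand-ins for the helpers in weathergen.train.utils (assumed behavior; their source is not shown in this diff):

# Sketch: how an all_reduce(SUM) can broadcast a string from rank 0.
import torch

def str_to_tensor(s: str) -> torch.Tensor:
    return torch.tensor([ord(c) for c in s], dtype=torch.int32)

def tensor_to_str(t: torch.Tensor) -> str:
    return "".join(chr(int(v)) for v in t.tolist())

run_id = "abc123"
root_payload = str_to_tensor(run_id)                          # rank 0's contribution
other_payload = torch.zeros(len(run_id), dtype=torch.int32)   # every other rank's contribution
# all_reduce with ReduceOp.SUM leaves the root's values on every rank:
assert tensor_to_str(root_payload + other_payload) == run_id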

src/weathergen/utils/config.py

Lines changed: 19 additions & 1 deletion
@@ -127,8 +127,26 @@ def _load_private_conf(private_home: Path | None) -> OmegaConf:
         _logger.info(f"Loading private config from WEATHERGEN_PRIVATE_CONF: {private_home}.")
 
     elif env_script_path.is_file():
+        _logger.info(f"Loading private config from platform-env.py: {env_script_path}.")
+        # This code does many checks to ensure that any error message is surfaced. Since it is a process call,
+        # it can be hard to diagnose the error.
+        # TODO: eventually, put all this wrapper code in a separate function
+        try:
+            result_hpc = subprocess.run(
+                [str(env_script_path), "hpc"], capture_output=True, text=True, check=True
+            )
+        except subprocess.CalledProcessError as e:
+            _logger.error(
+                f"Error while running platform-env.py: {e} {e.stderr} {e.stdout} {e.output} {e.returncode}"
+            )
+            raise
+        if result_hpc.returncode != 0:
+            _logger.error(f"Error while running platform-env.py: {result_hpc.stderr.strip()}")
+            raise RuntimeError(f"Error while running platform-env.py: {result_hpc.stderr.strip()}")
+        _logger.info(f"Detected HPC: {result_hpc.stdout.strip()}.")
+
         result = subprocess.run(
-            [str(env_script_path), "hpc-config"], capture_output=True, text=True
+            [str(env_script_path), "hpc-config"], capture_output=True, text=True, check=True
         )
         private_home = Path(result.stdout.strip())
         _logger.info(f"Loading private config from platform-env.py output: {private_home}.")
src/weathergen/utils/distributed.py

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+"""
+Utilities for writing distributed pytorch-based code.
+
+This module is adapted from code by Seb Hoffmann at:
+https://github.com/sehoffmann/dmlcloud/blob/develop/dmlcloud/core/distributed.py
+
+(same license as the rest of the code)
+Copyright (c) 2025, Sebastian Hoffmann
+"""
+
+# TODO: copy other utilities from dmlcloud such as root_wrap etc.
+# TODO: move the DDP code from trainer.py to this file
+
+import torch.distributed as dist
+
+SYNC_TIMEOUT_SEC = 60 * 60  # 1 hour
+
+
+def is_root(pg: dist.ProcessGroup | None = None) -> bool:
+    """
+    Check if the current rank is the root rank (rank 0).
+
+    Args:
+        pg (ProcessGroup, optional): The process group to work on. If None (default), the default process group will be used.
+    """
+    if not _is_distributed_initialized():
+        # If not initialized, it is assumed to be in single process mode.
+        # TODO: check what should happen if a process group is passed
+        return True
+    return dist.get_rank(pg) == 0
+
+
+def _is_distributed_initialized():
+    return dist.is_available() and dist.is_initialized()
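
A short usage sketch of the new helper; because is_root() returns True when torch.distributed was never initialized, call sites need no cf.rank plumbing:

# Sketch: guard rank-0-only work (checkpointing, config printing) uniformly.
from weathergen.utils.distributed import is_root

if is_root():
    print("runs on rank 0 under DDP, and always in single-process mode")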

src/weathergen/utils/logger.py

Lines changed: 5 additions & 2 deletions
@@ -31,9 +31,12 @@ def init_loggers():
 
     WARNING: this function resets all the logging handlers.
 
-    This function can be called only once, so that it can be called repeatedly in multiprocessing pipelines.
+    This function follows a singleton pattern: it will only operate once per process
+    and will be a no-op if called again.
     """
-    formatter = RelPathFormatter("%(pathname)s:%(lineno)d : %(levelname)-8s : %(message)s")
+    formatter = RelPathFormatter(
+        "%(asctime)s %(pathname)s:%(lineno)d : %(levelname)-8s : %(message)s"
+    )
     for package in ["obslearn", "weathergen"]:
         logger = logging.getLogger(package)
         logger.handlers.clear()
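
The effect of prepending %(asctime)s can be previewed with a plain logging.Formatter, used here as a stand-in for the repo's RelPathFormatter (which, judging by its name, rewrites %(pathname)s to a relative path):

# Sketch: preview of the new log format with stdlib logging only.
import logging

handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter("%(asctime)s %(pathname)s:%(lineno)d : %(levelname)-8s : %(message)s")
)
log = logging.getLogger("demo")
log.addHandler(handler)
log.warning("hello")
# e.g.: 2025-06-05 10:32:01,123 /path/to/module.py:42 : WARNING  : hello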
