Skip to content

Commit 8022b35

Browse files
kctezcanKerem Tezcan
andauthored
Hard coded model and results paths ->private_config - Ktezcan/develop/iss49 paths model run (ecmwf#83)
* changed model and run paths from hardcoded to private config * dummy commit for email * corrected some mistakes * fixing/testing loading config * fixed and tested config loader * reverted date reqs * cleanup * fixed more paths in the trainer.py file * small fix and more cleanups * removed unused import * converted to pathlib in model.py and trainer.py * converted paths to pathlib in config.py and train_logger.py * fixed: code expecting string path, got PAth object * reverted hanges allowing the cod to run * removed the #KCT:path comment * implement Tim's clarification * still removing KCT:path comments * corrected the typo again * removed the KCT:path comment yet again * changed from print to logger.info() * corrected mistake in config.py --------- Co-authored-by: Kerem Tezcan <ktezcan@balfrin-ln004.cscs.ch>
1 parent fb6c2e5 commit 8022b35

File tree

5 files changed

+50
-20
lines changed

5 files changed

+50
-20
lines changed

src/weathergen/__init__.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,27 @@ def evaluate():
8888
default=["ERA5"],
8989
help="Analysis output streams during evaluation.",
9090
)
91+
parser.add_argument(
92+
"--private_config",
93+
type=str,
94+
default=None,
95+
help="Path to private configuration file for paths.",
96+
)
9197

9298
args = parser.parse_args()
9399

100+
# get the paths from the private config
101+
private_cf = load_private_conf(args.private_config)
102+
94103
# TODO: move somewhere else
95104
init_loggers()
96105

97-
# load config if specified
98-
cf = Config.load(args.run_id, args.epoch)
106+
# load config: if run_id is full path, it loads from there
107+
cf = Config.load(args.run_id, args.epoch, private_cf["model_path"])
108+
109+
# add parameters from private (paths) config
110+
for k, v in private_cf.items():
111+
setattr(cf, k, v)
99112

100113
cf.run_history += [(cf.run_id, cf.istep)]
101114

src/weathergen/model/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ def print_num_parameters(self):
537537

538538
#########################################
539539
def load(self, run_id, epoch=None):
540-
path_run = Path("./models/") / run_id
540+
path_run = Path(self.cf.model_path) / run_id
541541
epoch_id = f"epoch{epoch:05d}" if epoch is not None else "latest"
542542
filename = f"{run_id}_{epoch_id}.chkpt"
543543

src/weathergen/train/trainer.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ def init(
7878
cf = self.init_streams(cf, run_id_contd)
7979

8080
# create output directory
81-
path_run = Path("./results") / cf.run_id
82-
path_model = Path("./models") / cf.run_id
81+
path_run = Path(cf.run_path) / cf.run_id
82+
path_model = Path(cf.model_path) / cf.run_id
8383
if self.cf.rank == 0:
8484
path_run.mkdir(exist_ok=True)
8585
path_model.mkdir(exist_ok=True)
@@ -790,12 +790,6 @@ def batch_to_device(self, batch):
790790

791791
###########################################
792792
def save_model(self, epoch=-1, name=None):
793-
path_model = Path("./models/") / self.cf.run_id
794-
epoch_str = "latest" if epoch == -1 else f"epoch{epoch:05d}"
795-
name_str = f"_{name}" if name is not None else ""
796-
file_out = path_model / f"{self.cf.run_id}_{epoch_str}_{name_str}.chkpt"
797-
temp_file_out = path_model / f"{self.cf.run_id}_{epoch_str}_{name_str}_temp.chkpt"
798-
799793
if self.cf.with_ddp and self.cf.with_fsdp:
800794
_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
801795
with FSDP.state_dict_type(
@@ -808,10 +802,22 @@ def save_model(self, epoch=-1, name=None):
808802
state = self.ddp_model.state_dict()
809803

810804
if self.cf.rank == 0:
805+
filename = "".join(
806+
[
807+
self.cf.run_id,
808+
"_",
809+
"latest" if epoch == -1 else f"epoch{epoch:05d}",
810+
("_" + name) if name is not None else "",
811+
]
812+
)
813+
base_path = Path(self.cf.model_path) / self.cf.run_id
814+
file_out: Path = base_path / (filename + ".chkpt")
815+
file_tmp: Path = base_path / (filename + "_tmp.chkpt")
811816
# save temp file (slow)
812-
torch.save(state, temp_file_out)
817+
torch.save(state, file_tmp)
813818
# move file (which is changing the link in the file system and very fast)
814-
temp_file_out.replace(file_out)
819+
file_tmp.replace(file_out)
820+
815821
# save config
816822
self.cf.save(epoch)
817823

src/weathergen/utils/config.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
import yaml
1515

16+
from weathergen.utils.logger import logger
17+
1618

1719
###########################################
1820
class Config:
@@ -30,7 +32,7 @@ def print(self):
3032
print("{}{} : {}".format("" if k == "reportypes" else " ", k, v))
3133

3234
def save(self, epoch=None):
33-
path_models = Path("./models")
35+
path_models = Path(self.model_path)
3436
# save in directory with model files
3537
dirname = path_models / self.run_id
3638
dirname.mkdir(exist_ok=True, parents=True)
@@ -41,21 +43,29 @@ def save(self, epoch=None):
4143
fname = dirname / f"model_{self.run_id}{epoch_str}.json"
4244

4345
json_str = json.dumps(self.__dict__)
44-
with open(fname, "w") as f:
46+
with fname.open("w") as f:
4547
f.write(json_str)
4648

4749
@staticmethod
48-
def load(run_id, epoch=None):
49-
if "/" in run_id: # assumed to be full path instead of just id
50+
def load(run_id: str, epoch: int = None, model_path: str = "./models") -> "Config":
51+
"""
52+
Load a configuration file from a given run_id and epoch.
53+
If run_id is a full path, loads it from the full path.
54+
"""
55+
if Path(run_id).exists(): # load from the full path if a full path is provided
5056
fname = Path(run_id)
57+
logger.info(f"Loading config from provided full run_id path: {fname}")
5158
else:
52-
path_models = Path("./models")
59+
path_models = Path(model_path)
5360
epoch_str = ""
5461
if epoch is not None:
5562
epoch_str = "_latest" if epoch == -1 else f"_epoch{epoch:05d}"
5663
fname = path_models / run_id / f"model_{run_id}{epoch_str}.json"
5764

58-
with open(fname) as f:
65+
logger.info(f"Loading config from specified run_id and epoch: {fname}")
66+
67+
# open the file and read into a config object
68+
with fname.open() as f:
5969
json_str = f.readlines()
6070

6171
cf = Config()

src/weathergen/utils/train_logger.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def log_metrics(self, stage: Stage, metrics: dict[str, float]) -> None:
5858

5959
# TODO: performance: we repeatedly open the file for each call. Better for multiprocessing
6060
# but we can probably do better and rely for example on the logging module.
61+
6162
with open(self.path_run / "metrics.json", "ab") as f:
6263
s = json.dumps(clean_metrics) + "\n"
6364
f.write(s.encode("utf-8"))
@@ -139,7 +140,7 @@ def read(run_id, epoch=-1):
139140
cf = Config.load(run_id, epoch)
140141
run_id = cf.run_id
141142

142-
result_dir = Path(f"./results/{run_id}")
143+
result_dir = Path(cf.run_path) / run_id
143144
fname_log_train = result_dir / f"{run_id}_train_log.txt"
144145
fname_log_val = result_dir / f"{run_id}_val_log.txt"
145146
fname_perf_val = result_dir / f"{run_id}_perf_log.txt"

0 commit comments

Comments
 (0)