Skip to content

Commit c3c5baa

Browse files
authored
Mk/develop/fix plot train 727 (ecmwf#738)
* Load model_path from private config if not provided * Use existing function to get private model path * Incorporated PR comments
1 parent 7536837 commit c3c5baa

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

src/weathergen/utils/config.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,17 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) ->
6767
fname = Path(run_id)
6868
_logger.info(f"Loading config from provided full run_id path: {fname}")
6969
else:
70-
# Load model config here...
70+
# Load model config here. In case model_path is not provided, get it from private conf
71+
if model_path is None:
72+
pconf = _load_private_conf()
73+
model_path = _get_config_attribute(
74+
config=pconf, attribute_name="model_path", fallback="models"
75+
)
7176
model_path = Path(model_path)
7277
fname = model_path / run_id / _get_model_config_file_name(run_id, epoch)
78+
assert fname.exists(), (
79+
"The fallback path to the model does not exist. Please provide a `model_path`."
80+
)
7381

7482
_logger.info(f"Loading config from specified run_id and epoch: {fname}")
7583

@@ -235,7 +243,7 @@ def _load_overwrite_conf(overwrite: Path | dict | DictConfig) -> DictConfig:
235243
return overwrite_config
236244

237245

238-
def _load_private_conf(private_home: Path | None) -> DictConfig:
246+
def _load_private_conf(private_home: Path | None = None) -> DictConfig:
239247
"Return the private configuration."
240248
"If none, take it from the environment variable WEATHERGEN_PRIVATE_CONF."
241249

@@ -246,7 +254,7 @@ def _load_private_conf(private_home: Path | None) -> DictConfig:
246254

247255
elif "WEATHERGEN_PRIVATE_CONF" in os.environ:
248256
private_home = Path(os.environ["WEATHERGEN_PRIVATE_CONF"])
249-
_logger.info(f"Loading private config fromWEATHERGEN_PRIVATE_CONF:{private_home}.")
257+
_logger.info(f"Loading private config from WEATHERGEN_PRIVATE_CONF:{private_home}.")
250258

251259
elif env_script_path.is_file():
252260
_logger.info(f"Loading private config from platform-env.py: {env_script_path}.")
@@ -353,9 +361,9 @@ def set_paths(config: Config) -> Config:
353361

354362

355363
def _get_config_attribute(config: Config, attribute_name: str, fallback: str) -> str:
356-
"""Get an attribute from a Config. If not, fall back to path_shared_working_dir concatenated
357-
with the desired fallback path. Raise an error if neither the attribute nor
358-
is specified."""
364+
"""Get an attribute from a Config. If not available, fall back to path_shared_working_dir
365+
concatenated with the desired fallback path. Raise an error if neither the attribute nor a
366+
fallback is specified."""
359367
attribute = OmegaConf.select(config, attribute_name)
360368
fallback_root = OmegaConf.select(config, "path_shared_working_dir")
361369
assert attribute is not None or fallback_root is not None, (

0 commit comments

Comments
 (0)