Skip to content

Commit 3816abd

Browse files
tjhunterclessig
andauthored
[569] Load eagerly the stream content in order (ecmwf#585)
* changes * change * changes * Remove loading of streams also from inference. --------- Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
1 parent b035c85 commit 3816abd

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

src/weathergen/run_train.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def inference_from_args(argl: list[str]):
6161
cf = config.set_run_id(cf, args.run_id, args.reuse_run_id)
6262

6363
cf.run_history += [(args.from_run_id, cf.istep)]
64-
cf.streams = config.load_streams(Path(cf.streams_directory))
6564
cf = config.set_paths(cf)
6665

6766
trainer = Trainer()
@@ -115,7 +114,6 @@ def train_continue() -> None:
115114

116115
# track history of run to ensure traceability of results
117116
cf.run_history += [(args.from_run_id, cf.istep)]
118-
cf.streams = config.load_streams(Path(cf.streams_directory))
119117
cf = config.set_paths(cf)
120118

121119
if args.finetune_forecast:

src/weathergen/utils/config.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,14 @@ def load_config(
121121
# all the paths may be concatenated with ":"
122122
p = str(overwrite).split(":")
123123
for path in p:
124-
overwrite_configs.append(_load_overwrite_conf(Path(path)))
124+
c = _load_overwrite_conf(Path(path))
125+
c = _load_streams_in_config(c)
126+
overwrite_configs.append(c)
125127
else:
126128
# If it is a dict or DictConfig, we can directly use it
127-
overwrite_configs.append(_load_overwrite_conf(overwrite))
129+
c = _load_overwrite_conf(overwrite)
130+
c = _load_streams_in_config(c)
131+
overwrite_configs.append(c)
128132

129133
if from_run_id is None:
130134
base_config = _load_default_conf()
@@ -135,6 +139,22 @@ def load_config(
135139
return OmegaConf.merge(base_config, private_config, *overwrite_configs)
136140

137141

142+
def _load_streams_in_config(config: Config) -> Config:
143+
"""If the config contains a streams_directory, loads the streams and returns the config with
144+
the streams set."""
145+
streams_directory = config.get("streams_directory", None)
146+
config = config.copy()
147+
if streams_directory is not None:
148+
streams_directory = Path(streams_directory)
149+
if not streams_directory.is_dir():
150+
msg = f"Streams directory {streams_directory} does not exist."
151+
raise FileNotFoundError(msg)
152+
153+
_logger.info(f"Loading streams from {streams_directory}")
154+
config.streams = load_streams(streams_directory)
155+
return config
156+
157+
138158
def set_run_id(config: Config, run_id: str | None, reuse_run_id: bool) -> Config:
139159
"""
140160
Determine and set run_id of current run.

0 commit comments

Comments
 (0)