Skip to content

Commit 05d3276

Browse files
authored
Distinguish between primary runs ('candidates') and secondary runs (#64)
* Distinguish between primary runs ('candidates') and secondary runs
* Docstrings
1 parent c8fdc47 commit 05d3276

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

workflow/Snakefile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ rule sandbox_all:
5454
input:
5555
expand(
5656
rules.create_inference_sandbox.output.sandbox,
57-
run_id=collect_all_runs(),
57+
run_id=collect_all_candidates(),
5858
),
5959

6060

@@ -64,7 +64,7 @@ rule run_inference_all:
6464
expand(
6565
OUT_ROOT / "data/runs/{run_id}/{init_time}/raw",
6666
init_time=[t.strftime("%Y%m%d%H%M") for t in REFTIMES],
67-
run_id=collect_all_runs(),
67+
run_id=collect_all_candidates(),
6868
),
6969

7070

@@ -73,7 +73,7 @@ rule verif_metrics_all:
7373
expand(
7474
rules.verif_metrics.output,
7575
init_time=[t.strftime("%Y%m%d%H%M") for t in REFTIMES],
76-
run_id=collect_all_runs(),
76+
run_id=collect_all_candidates(),
7777
),
7878

7979

workflow/rules/common.smk

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,13 @@ REFTIMES = _reftimes()
6666

6767

6868
def collect_all_runs():
69-
"""Collect all runs defined in the configuration."""
69+
"""Collect all runs defined in the configuration, including secondary runs."""
7070
runs = {}
7171
for run_entry in copy.deepcopy(config["runs"]):
7272
model_type = next(iter(run_entry))
7373
run_config = run_entry[model_type]
7474
run_config["model_type"] = model_type
75+
run_config["is_candidate"] = True
7576
run_id = run_config["mlflow_id"][0:9]
7677

7778
if model_type == "interpolator":
@@ -82,6 +83,7 @@ def collect_all_runs():
8283
# Ensure a proper 'forecaster' entry exists with model_type
8384
fore_cfg = copy.deepcopy(run_config["forecaster"])
8485
fore_cfg["model_type"] = "forecaster"
86+
fore_cfg["is_candidate"] = False # exclude from outputs
8587
runs[tail_id] = fore_cfg
8688
run_id = f"{run_id}-{tail_id}"
8789

@@ -90,6 +92,16 @@ def collect_all_runs():
9092
return runs
9193

9294

95+
def collect_all_candidates():
    """Return only the runs flagged as candidates.

    Calls ``collect_all_runs()`` and keeps the entries whose configuration
    carries a truthy ``"is_candidate"`` value; entries that lack the key are
    treated as non-candidates. The result maps each run id to its run
    configuration, in the same order as ``collect_all_runs()`` provides them.
    """
    return {
        identifier: settings
        for identifier, settings in collect_all_runs().items()
        if settings.get("is_candidate", False)
    }
103+
104+
93105
def collect_all_baselines():
94106
"""Collect all baselines defined in the configuration."""
95107
baselines = {}
@@ -106,7 +118,8 @@ def collect_experiment_participants():
106118
for base in BASELINE_CONFIGS.keys():
107119
participants[base] = OUT_ROOT / f"data/baselines/{base}/verif_aggregated.nc"
108120
for exp in RUN_CONFIGS.keys():
109-
participants[exp] = OUT_ROOT / f"data/runs/{exp}/verif_aggregated.nc"
121+
if RUN_CONFIGS[exp].get("is_candidate", False):
122+
participants[exp] = OUT_ROOT / f"data/runs/{exp}/verif_aggregated.nc"
110123
return participants
111124

112125

0 commit comments

Comments
 (0)