
Commit 1175afc

Merge branch 'main' into feat/decouple-inference-preparation-and-execution

2 parents 7b2826d + 98a7dfc
15 files changed: +151 additions, -154 deletions
README.md

Lines changed: 4 additions & 4 deletions

@@ -31,25 +31,25 @@ description: |
 dates:
   start: 2020-01-01T12:00
   end: 2020-01-10T00:00
-  frequency: 54h
-
-lead_time: 120h
+  frequency: 60h

 runs:
   - forecaster:
       mlflow_id: 2f962c89ff644ca7940072fa9cd088ec
       label: Stage D - N320 global grid with CERRA finetuning
+      steps: 0/120/6
   - forecaster:
       mlflow_id: d0846032fc7248a58b089cbe8fa4c511
       label: M-1 forecaster
+      steps: 0/120/6


 baselines:
   - baseline:
       baseline_id: COSMO-E
       label: COSMO-E
       root: /store_new/mch/msopr/ml/COSMO-E
-      steps: 0/126/6
+      steps: 0/120/6

 analysis:
   label: COSMO KENDA

config/forecasters-co1e.yaml

Lines changed: 2 additions & 3 deletions

@@ -8,13 +8,12 @@ dates:
   end: 2020-01-10T00:00
   frequency: 54h

-lead_time: 120h
-
 runs:
   - forecaster:
       mlflow_id: 2174c939c8844555a52843b71219d425
       label: Cosmo 1km + era5 N320, finetuned on cerra checkpoint, lam resolution 11
       config: resources/inference/configs/forecaster_no_trimedge_fromtraining.yaml
+      steps: 0/120/6
       inference_resources:
         gpu: 4
         tasks: 4
@@ -24,7 +23,7 @@ baselines:
       baseline_id: COSMO-1E
       label: COSMO-1E
       root: /scratch/mch/bhendj/COSMO-1E
-      steps: 0/126/6
+      steps: 0/33/6

 analysis:
   label: COSMO KENDA

config/forecasters.yaml

Lines changed: 5 additions & 10 deletions

@@ -1,29 +1,24 @@
 # yaml-language-server: $schema=../workflow/tools/config.schema.json
 description: |
-  This is an experiment to do blabla.
+  Evaluate skill of COSMO-E emulator (M-1 forecaster).

 dates:
   start: 2020-01-01T12:00
   end: 2020-01-10T00:00
-  # end: 2020-03-30T00:00
-  frequency: 36h
-
-lead_time: 120h
+  frequency: 60h

 runs:
-  - forecaster:
-      mlflow_id: 2f962c89ff644ca7940072fa9cd088ec
-      label: Stage D - N320 global grid with CERRA finetuning
   - forecaster:
       mlflow_id: d0846032fc7248a58b089cbe8fa4c511
       label: M-1 forecaster
+      steps: 0/120/6

 baselines:
   - baseline:
       baseline_id: COSMO-E
       label: COSMO-E
       root: /store_new/mch/msopr/ml/COSMO-E
-      steps: 0/126/6
+      steps: 0/120/6

 analysis:
   label: COSMO KENDA
@@ -38,7 +33,7 @@ locations:
 profile:
   executor: slurm
   global_resources:
-    gpus: 15
+    gpus: 16
   default_resources:
     slurm_partition: "postproc"
     cpus_per_task: 1

config/interpolators.yaml

Lines changed: 19 additions & 26 deletions

@@ -1,55 +1,48 @@
 # yaml-language-server: $schema=../workflow/tools/config.schema.json
 description: |
-  Stretched interpolator vs LAM interpolator.
+  Evaluate skill of SGM interpolator (M-2 interpolator).

 dates:
   start: 2020-01-01T12:00
   end: 2020-01-10T00:00
-  frequency: 54h
-
-lead_time: 120h
+  frequency: 60h

 runs:
   - interpolator:
-      mlflow_id: 9c18b90074214d769b8b383722fc5a06
-      label: LAM Interpolator (COSMO-E analysis)
-      steps: 0/121/1
-      config: resources/inference/configs/interpolator_from_test_data.yaml
+      mlflow_id: 8d1e0410ca7d4f74b368b3079878259a
+      label: M-2 interpolator (KENDA)
+      steps: 0/120/1
+      config: resources/inference/configs/interpolator_from_test_data_stretched.yaml
       forecaster: null
       extra_dependencies:
-        - git+https://github.com/ecmwf/anemoi-inference@fix/cutout-preprocessors
-        - torch-geometric==2.6.1
-        - anemoi-graphs==0.5.2
-  - interpolator:
-      mlflow_id: 9c18b90074214d769b8b383722fc5a06
-      label: LAM Interpolator (M-1 forecaster)
-      steps: 0/121/1
-      forecaster:
-        mlflow_id: d0846032fc7248a58b089cbe8fa4c511
-        config: resources/inference/configs/forecaster_with_global.yaml
-      extra_dependencies:
-        - git+https://github.com/ecmwf/anemoi-inference@fix/cutout-preprocessors
+        - git+https://github.com/ecmwf/anemoi-inference@14189907b4f4e3b204b7994f828831b8aa51e9b6
         - torch-geometric==2.6.1
         - anemoi-graphs==0.5.2
   - interpolator:
-      mlflow_id: 07c3d9698db14d859b78bb712a65bbbf
-      label: SGM Interpolator (M-1 forecaster)
-      steps: 0/121/1
+      mlflow_id: 8d1e0410ca7d4f74b368b3079878259a
+      label: M-2 interpolator (M-1 forecaster)
+      steps: 0/120/1
       config: resources/inference/configs/interpolator_stretched.yaml
       forecaster:
         mlflow_id: d0846032fc7248a58b089cbe8fa4c511
         config: resources/inference/configs/forecaster_with_global.yaml
+        steps: 0/120/6
       extra_dependencies:
-        - git+https://github.com/ecmwf/anemoi-inference@fix/cutout-preprocessors
+        - git+https://github.com/ecmwf/anemoi-inference@14189907b4f4e3b204b7994f828831b8aa51e9b6
         - torch-geometric==2.6.1
         - anemoi-graphs==0.5.2
+  - forecaster:
+      mlflow_id: d0846032fc7248a58b089cbe8fa4c511
+      label: M-1 forecaster
+      config: resources/inference/configs/forecaster_with_global.yaml
+      steps: 0/120/6

 baselines:
   - baseline:
       baseline_id: COSMO-E-1h
       label: COSMO-E
       root: /scratch/mch/bhendj/COSMO-E
-      steps: 0/121/1
+      steps: 0/120/1

 analysis:
   label: COSMO KENDA
@@ -65,7 +58,7 @@ locations:
 profile:
   executor: slurm
   global_resources:
-    gpus: 15
+    gpus: 16
   default_resources:
     slurm_partition: "postproc"
     cpus_per_task: 1

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -50,4 +50,4 @@ markers = [
 packages = [
     "src/evalml",
     "src/verification"
-]
+]

src/evalml/config.py

Lines changed: 31 additions & 7 deletions

@@ -1,7 +1,7 @@
 from pathlib import Path
 from typing import Dict, List, Any

-from pydantic import BaseModel, Field, RootModel, HttpUrl
+from pydantic import BaseModel, Field, RootModel, HttpUrl, field_validator

 PROJECT_ROOT = Path(__file__).parents[2]

@@ -70,9 +70,15 @@ class RunConfig(BaseModel):
         None,
         description="The label for the run that will be used in experiment results such as reports and figures.",
     )
-    steps: str | None = Field(
-        None,
-        description="Forecast steps to be used from interpolator, e.g. '0/126/6'.",
+    steps: str = Field(
+        ...,
+        description=(
+            "Forecast lead times in hours, formatted as 'start/end/step'. "
+            "The range includes the start lead time and continues with the given step "
+            "until reaching or exceeding the end lead time. "
+            "Example: '0/120/6' for lead times every 6 hours up to 120 h, "
+            "or '0/33/6' up to 30 h."
+        ),
     )
     extra_dependencies: List[str] = Field(
         default_factory=list,
@@ -86,6 +92,27 @@ class RunConfig(BaseModel):

     config: Dict[str, Any] | str

+    @field_validator("steps")
+    def validate_steps(cls, v: str) -> str:
+        if "/" not in v:
+            raise ValueError(
+                f"Steps must follow the format 'start/stop/step', got '{v}'"
+            )
+        parts = v.split("/")
+        if len(parts) != 3:
+            raise ValueError("Steps must be formatted as 'start/end/step'.")
+        try:
+            start, end, step = map(int, parts)
+        except ValueError:
+            raise ValueError("Start, end, and step must be integers.")
+        if start > end:
+            raise ValueError(
+                f"Start ({start}) must be less than or equal to end ({end})."
+            )
+        if step <= 0:
+            raise ValueError(f"Step ({step}) must be a positive integer.")
+        return v
+

 class ForecasterConfig(RunConfig):
     """Single training run stored in MLflow."""

@@ -240,9 +267,6 @@ class ConfigModel(BaseModel):
         description="Description of the experiment, e.g. 'Hindcast of the 2023 season.'",
     )
     dates: Dates | ExplicitDates
-    lead_time: str = Field(
-        ..., description="Forecast length, e.g. '120h'", pattern=r"^\d+[hmd]$"
-    )
     runs: List[ForecasterItem | InterpolatorItem] = Field(
         ...,
         description="Dictionary of runs to execute, with run IDs as keys and configurations as values.",
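
Note on the new per-run 'steps' schema: the sketch below (not part of this commit, helper name expand_steps is hypothetical) shows one way a 'start/end/step' string expands into lead times in hours, consistent with the examples in the field description ('0/120/6' up to 120 h, '0/33/6' up to 30 h).

def expand_steps(steps: str) -> list[int]:
    """Expand a 'start/end/step' string into lead times in hours.

    '0/120/6' -> 0, 6, ..., 120 (the end is included when hit exactly);
    '0/33/6'  -> 0, 6, ..., 30  (stops before overshooting the end).
    """
    start, end, step = map(int, steps.split("/"))
    return list(range(start, end + 1, step))


print(expand_steps("0/33/6"))  # [0, 6, 12, 18, 24, 30]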

workflow/Snakefile

Lines changed: 3 additions & 3 deletions

@@ -54,7 +54,7 @@ rule sandbox_all:
     input:
         expand(
             rules.create_inference_sandbox.output.sandbox,
-            run_id=collect_all_runs(),
+            run_id=collect_all_candidates(),
         ),


@@ -64,7 +64,7 @@ rule run_inference_all:
         expand(
             OUT_ROOT / "data/runs/{run_id}/{init_time}/raw",
             init_time=[t.strftime("%Y%m%d%H%M") for t in REFTIMES],
-            run_id=collect_all_runs(),
+            run_id=collect_all_candidates(),
         ),


@@ -73,7 +73,7 @@ rule verif_metrics_all:
         expand(
             rules.verif_metrics.output,
             init_time=[t.strftime("%Y%m%d%H%M") for t in REFTIMES],
-            run_id=collect_all_runs(),
+            run_id=collect_all_candidates(),
         ),

workflow/rules/common.smk

Lines changed: 15 additions & 2 deletions

@@ -67,12 +67,13 @@ REFTIMES = _reftimes()


 def collect_all_runs():
-    """Collect all runs defined in the configuration."""
+    """Collect all runs defined in the configuration, including secondary runs."""
     runs = {}
     for run_entry in copy.deepcopy(config["runs"]):
         model_type = next(iter(run_entry))
         run_config = run_entry[model_type]
         run_config["model_type"] = model_type
+        run_config["is_candidate"] = True
         run_id = run_config["mlflow_id"][0:9]

         if model_type == "interpolator":
@@ -83,6 +84,7 @@ def collect_all_runs():
             # Ensure a proper 'forecaster' entry exists with model_type
             fore_cfg = copy.deepcopy(run_config["forecaster"])
             fore_cfg["model_type"] = "forecaster"
+            fore_cfg["is_candidate"] = False  # exclude from outputs
             runs[tail_id] = fore_cfg
             run_id = f"{run_id}-{tail_id}"

@@ -91,6 +93,16 @@ def collect_all_runs():
     return runs


+def collect_all_candidates():
+    """Collect participating runs ('candidates') only."""
+    runs = collect_all_runs()
+    candidates = {}
+    for run_id, run_config in runs.items():
+        if run_config.get("is_candidate", False):
+            candidates[run_id] = run_config
+    return candidates
+
+
 def collect_all_baselines():
     """Collect all baselines defined in the configuration."""
     baselines = {}
@@ -107,7 +119,8 @@ def collect_experiment_participants():
     for base in BASELINE_CONFIGS.keys():
         participants[base] = OUT_ROOT / f"data/baselines/{base}/verif_aggregated.nc"
     for exp in RUN_CONFIGS.keys():
-        participants[exp] = OUT_ROOT / f"data/runs/{exp}/verif_aggregated.nc"
+        if RUN_CONFIGS[exp].get("is_candidate", False):
+            participants[exp] = OUT_ROOT / f"data/runs/{exp}/verif_aggregated.nc"
     return participants
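
For context on the new 'is_candidate' flag: a standalone illustration with toy data (not part of this commit) of how collect_all_candidates() keeps participating runs and drops a secondary forecaster that exists only to drive an interpolator. The run IDs are the first 9 characters of the mlflow_id, as in the workflow.

# Toy runs dict shaped like the output of collect_all_runs().
runs = {
    "8d1e0410c": {"model_type": "interpolator", "is_candidate": True},
    "d0846032f": {"model_type": "forecaster", "is_candidate": False},  # secondary, drives the interpolator
    "2174c939c": {"model_type": "forecaster", "is_candidate": True},
}

# Same filter as collect_all_candidates(): only candidates reach the sandbox,
# inference, and verification targets and the experiment participants.
candidates = {rid: cfg for rid, cfg in runs.items() if cfg.get("is_candidate", False)}
print(sorted(candidates))  # ['2174c939c', '8d1e0410c']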

workflow/rules/data.smk

Lines changed: 4 additions & 4 deletions

@@ -18,15 +18,15 @@ if "extract_cosmoe" in config.get("include-optional-rules", []):
             runtime="24h",
         params:
             year_postfix=lambda wc: f"FCST{wc.year}",
-            lead_time="0/126/6",
+            steps="0/120/6",
         log:
             OUT_ROOT / "logs/extract-cosmoe-fcts-{year}.log",
         shell:
             """
             python workflow/scripts/extract_baseline_fct.py \
                 --archive_dir {input.archive}/{params.year_postfix} \
                 --output_store {output.fcts} \
-                --lead_time {params.lead_time} \
+                --steps {params.steps} \
                 > {log} 2>&1
             """

@@ -45,14 +45,14 @@ if "extract_cosmo1e" in config.get("include-optional-rules", []):
             runtime="24h",
         params:
             year_postfix=lambda wc: f"FCST{wc.year}",
-            lead_time="0/34/1",
+            steps="0/33/1",
         log:
             OUT_ROOT / "logs/extract-cosmo1e-fcts-{year}.log",
         shell:
             """
             python workflow/scripts/extract_baseline_fct.py \
                 --archive_dir {input.archive}/{params.year_postfix} \
                 --output_store {output.fcts} \
-                --lead_time {params.lead_time} \
+                --steps {params.steps} \
                 > {log} 2>&1
             """

workflow/rules/inference.smk

Lines changed: 8 additions & 2 deletions

@@ -135,6 +135,12 @@ def get_resource(wc, field: str, default):
     return getattr(rc["inference_resources"], field) or default


+def get_leadtime(wc):
+    """Get the lead time from the run config."""
+    start, end, step = RUN_CONFIGS[wc.run_id]["steps"].split("/")
+    return f"{end}h"
+
+
 rule prepare_inference_forecaster:
     localrule: True
     input:
@@ -151,7 +157,7 @@ rule prepare_inference_forecaster:
         checkpoints_path=parse_input(
             input.pyproject, parse_toml, key="tool.anemoi.checkpoints_path"
         ),
-        lead_time=config["lead_time"],
+        lead_time=lambda wc: get_leadtime(wc),
         output_root=(OUT_ROOT / "data").resolve(),
         resources_root=Path("resources/inference").resolve(),
         reftime_to_iso=lambda wc: datetime.strptime(
@@ -235,7 +241,7 @@ rule prepare_inference_interpolator:
         checkpoints_path=parse_input(
             input.pyproject, parse_toml, key="tool.anemoi.checkpoints_path"
         ),
-        lead_time=config["lead_time"],
+        lead_time=lambda wc: get_leadtime(wc),
         output_root=(OUT_ROOT / "data").resolve(),
         resources_root=Path("resources/inference").resolve(),
         reftime_to_iso=lambda wc: datetime.strptime(
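
A quick sketch of the new lead-time derivation, rewritten standalone (without the Snakemake wildcards object and the RUN_CONFIGS lookup); it mirrors get_leadtime above, which takes the end of the run's 'steps' range as the inference lead time instead of the removed global config["lead_time"].

def leadtime_from_steps(steps: str) -> str:
    """Derive the inference lead time from a 'start/end/step' string."""
    _start, end, _step = steps.split("/")
    return f"{end}h"


assert leadtime_from_steps("0/120/6") == "120h"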
