
Commit 5e9c3ef

tjhunter and Jubeku authored
[1092] Adds pushing metrics to the evaluation pipeline (ecmwf#1127)
* changes
* changes
* changes
* changes
* changes
* scores successfully pushed to MLFlow, still need to refactor
* try to batch upload all metrics from same runid
* batch logging all scores of each run_id
* get parent_run by from_run_id
* changes
* cleanups
* bug fixes
* typing issue
* Cleanup
* pdb
* integration test

---------

Co-authored-by: Jubeku <julian.kuehnert@ecmwf.int>
1 parent b4fc1a2 commit 5e9c3ef

File tree

12 files changed, +701 -10 lines changed

integration_tests/small1_test.py

Lines changed: 2 additions & 1 deletion
@@ -134,7 +134,8 @@ def evaluate_results(run_id):
             },
         }
     )
-    evaluate_from_config(cfg)
+    # Not passing the mlflow client for tests.
+    evaluate_from_config(cfg, None)


 def load_metrics(run_id):

packages/common/src/weathergen/common/io.py

Lines changed: 0 additions & 1 deletion
@@ -83,7 +83,6 @@ def combine(cls, others: list["IOReaderData"]) -> "IOReaderData":

         others is list of ReaderData instances.
         """
-
         assert len(others) > 0, len(others)

         other = others[0]
packages/common/src/weathergen/common/platform_env.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+"""
+Platform environment configuration for WeatherGenerator.
+
+These are loaded from secrets in the private repository.
+"""
+
+import importlib
+import importlib.util
+from functools import lru_cache
+from typing import Protocol
+
+from weathergen.common.config import _REPO_ROOT
+
+
+class PlatformEnv(Protocol):
+    """
+    Interface for platform environment configuration.
+    """
+
+    def get_hpc(self) -> str | None: ...
+
+    def get_hpc_user(self) -> str | None: ...
+
+    def get_hpc_config(self) -> str | None: ...
+
+    def get_hpc_certificate(self) -> str | None: ...
+
+
+@lru_cache(maxsize=1)
+def get_platform_env() -> PlatformEnv:
+    """
+    Loads the platform environment module from the private repository.
+    """
+    env_script_path = _REPO_ROOT.parent / "WeatherGenerator-private" / "hpc" / "platform-env.py"
+    spec = importlib.util.spec_from_file_location("platform_env", env_script_path)
+    platform_env = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(platform_env)  # type: ignore
+    return platform_env  # type: ignore
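
Note that `get_platform_env` returns the loaded module object itself, so plain module-level functions satisfy the `PlatformEnv` protocol structurally (hence the `# type: ignore` markers). A minimal sketch of what the private `platform-env.py` could look like; the environment variable names and bodies below are illustrative assumptions, not the actual private file:

# Hypothetical sketch of WeatherGenerator-private/hpc/platform-env.py.
# The real file ships with the private repository; everything here is a placeholder.
import os


def get_hpc() -> str | None:
    # Name of the HPC system this environment targets.
    return os.environ.get("WG_HPC")


def get_hpc_user() -> str | None:
    return os.environ.get("WG_HPC_USER")


def get_hpc_config() -> str | None:
    # Path to a private OmegaConf YAML holding MLflow settings (used by --push-metrics below).
    return os.environ.get("WG_HPC_CONFIG")


def get_hpc_certificate() -> str | None:
    return os.environ.get("WG_HPC_CERT")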

packages/evaluate/pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -10,8 +10,9 @@ dependencies = [
     "xhistogram",
     "panel",
     "omegaconf",
-    "weathergen-common",
     "plotly>=6.2.0",
+    "weathergen-common",
+    "weathergen-metrics",
 ]

 [dependency-groups]

packages/evaluate/src/weathergen/evaluate/plot_utils.py

Lines changed: 2 additions & 2 deletions
@@ -30,7 +30,7 @@ def collect_streams(runs: dict):
     return sorted({s for run in runs.values() for s in run["streams"].keys()})


-def collect_channels(scores_dict: dict, metric: str, region: str, runs) -> dict:
+def collect_channels(scores_dict: dict, metric: str, region: str, runs) -> list[str]:
     """Get all unique channels available for given metric and region across runs.

     Parameters
@@ -56,7 +56,7 @@ def collect_channels(scores_dict: dict, metric: str, region: str, runs) -> dict:
         if run_id not in run_data:
             continue
         values = run_data[run_id]["channel"].values
-        channels.update(np.atleast_1d(values))
+        channels.update([str(x) for x in np.atleast_1d(values)])
     return list(channels)

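The `str(x)` cast matters because values pulled from a DataArray coordinate are numpy scalars such as `numpy.str_` rather than the builtin type that the new `list[str]` annotation promises. A small illustration (the channel names are made up):

import numpy as np

# Values as they typically come out of an xarray "channel" coordinate.
values = np.atleast_1d(np.array(["2t", "10u"]))
print(type(values[0]))       # <class 'numpy.str_'>: a numpy scalar type (a str subclass)
print(type(str(values[0])))  # <class 'str'>: the normalized builtin string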

packages/evaluate/src/weathergen/evaluate/run_evaluation.py

Lines changed: 69 additions & 4 deletions
@@ -3,6 +3,7 @@
 # dependencies = [
 #     "weathergen-evaluate",
 #     "weathergen-common",
+#     "weathergen-metrics",
 # ]
 # [tool.uv.sources]
 # weathergen-evaluate = { path = "../../../../../packages/evaluate" }
@@ -14,36 +15,57 @@
 from collections import defaultdict
 from pathlib import Path

+import mlflow
+from mlflow.client import MlflowClient
 from omegaconf import OmegaConf
+from xarray import DataArray

 from weathergen.common.config import _REPO_ROOT
+from weathergen.common.platform_env import get_platform_env
 from weathergen.evaluate.io_reader import WeatherGenReader
+from weathergen.evaluate.plot_utils import collect_channels
 from weathergen.evaluate.utils import (
     calc_scores_per_stream,
     metric_list_to_json,
     plot_data,
     plot_summary,
     retrieve_metric_from_json,
 )
+from weathergen.metrics.mlflow_utils import (
+    MlFlowUpload,
+    get_or_create_mlflow_parent_run,
+    log_scores,
+    setup_mlflow,
+)

 _logger = logging.getLogger(__name__)

 _DEFAULT_PLOT_DIR = _REPO_ROOT / "plots"

+_platform_env = get_platform_env()
+

 def evaluate() -> None:
     # By default, arguments from the command line are read.
     evaluate_from_args(sys.argv[1:])


 def evaluate_from_args(argl: list[str]) -> None:
+    # configure logging
+    logging.basicConfig(level=logging.INFO)
     parser = argparse.ArgumentParser(description="Fast evaluation of WeatherGenerator runs.")
     parser.add_argument(
         "--config",
         type=str,
         default=None,
         help="Path to the configuration yaml file for plotting. e.g. config/plottig_config.yaml",
     )
+    parser.add_argument(
+        "--push-metrics",
+        required=False,
+        action="store_true",
+        help="(optional) Upload scores to MLFlow.",
+    )

     args = parser.parse_args(argl)
     if args.config:
@@ -53,13 +75,19 @@ def evaluate_from_args(argl: list[str]) -> None:
             "No config file provided, using the default template config (please edit accordingly)"
         )
         config = Path(_REPO_ROOT / "config" / "evaluate" / "eval_config.yml")
-    evaluate_from_config(OmegaConf.load(config))
+    mlflow_client: MlflowClient | None = None
+    if args.push_metrics:
+        hpc_conf = _platform_env.get_hpc_config()
+        assert hpc_conf is not None
+        private_home = Path(hpc_conf)
+        private_cf = OmegaConf.load(private_home)
+        mlflow_client = setup_mlflow(private_cf)
+        _logger.info(f"MLFlow client set up: {mlflow_client}")

+    evaluate_from_config(OmegaConf.load(config), mlflow_client)

-def evaluate_from_config(cfg):
-    # configure logging
-    logging.basicConfig(level=logging.INFO)

+def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None:
     # load configuration

     runs = cfg.run_ids
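
At this point the flag is fully wired: `--push-metrics` resolves the private config via the platform environment, builds an MLflow client, and hands it to `evaluate_from_config`. For reference, the flag can also be exercised programmatically, since `evaluate_from_args` takes the same argument list the CLI receives. A minimal sketch (the config path is illustrative, and pushing assumes the private platform environment is available to resolve MLflow credentials):

from weathergen.evaluate.run_evaluation import evaluate_from_args

# Equivalent to: evaluate --config config/evaluate/eval_config.yml --push-metrics
evaluate_from_args(["--config", "config/evaluate/eval_config.yml", "--push-metrics"])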
@@ -149,6 +177,43 @@ def evaluate_from_config(cfg):
                 {"metric": metric}
             )

+    if mlflow_client:
+        # Reorder scores_dict to push to MLFlow per run_id:
+        # Create a new defaultdict with the target structure: [run_id][metric][region][stream]
+        reordered_dict: dict[str, dict[str, dict[str, dict[str, DataArray]]]] = defaultdict(
+            lambda: defaultdict(lambda: defaultdict(dict))
+        )
+
+        # Iterate through the original dictionary to get all keys and the final value
+        for metric, regions_dict in scores_dict.items():
+            for region, streams_dict in regions_dict.items():
+                for stream, runs_dict in streams_dict.items():
+                    for run_id, final_dict in runs_dict.items():
+                        # Assign the final_dict to the new structure using the reordered keys
+                        reordered_dict[run_id][metric][region][stream] = final_dict
+
+        channels_set = collect_channels(scores_dict, metric, region, runs)
+
+        for run_id, run in runs.items():
+            reader = WeatherGenReader(run, run_id, private_paths)
+            from_run_id = reader.inference_cfg["from_run_id"]
+            parent_run = get_or_create_mlflow_parent_run(mlflow_client, from_run_id)
+            _logger.info(f"MLFlow parent run: {parent_run}")
+            phase = "eval"
+            with mlflow.start_run(run_id=parent_run.info.run_id):
+                with mlflow.start_run(
+                    run_name=f"{phase}_{from_run_id}_{run_id}",
+                    parent_run_id=parent_run.info.run_id,
+                    nested=True,
+                ) as run:
+                    mlflow.set_tags(MlFlowUpload.run_tags(run_id, phase, from_run_id))
+                    log_scores(
+                        reordered_dict[run_id],
+                        mlflow_client,
+                        run.info.run_id,
+                        channels_set,
+                    )
+
     # plot summary
     if scores_dict and cfg.evaluation.get("summary_plots", True):
         _logger.info("Started creating summary plots..")
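
The `weathergen.metrics.mlflow_utils` module is part of this commit but not shown in this view. A rough sketch of the shape its helpers would need in order to support the call sites above, inferred from those calls and using only standard `mlflow` APIs; the experiment name, tag keys, and metric key layout are assumptions rather than the committed implementation:

# Inferred sketch of weathergen/metrics/mlflow_utils.py; not the committed code.
import mlflow
from mlflow.client import MlflowClient


class MlFlowUpload:
    @staticmethod
    def run_tags(run_id: str, phase: str, from_run_id: str) -> dict[str, str]:
        # Tags identifying the evaluation run and the training run it derives from.
        return {"run_id": run_id, "phase": phase, "from_run_id": from_run_id}


def setup_mlflow(private_cf) -> MlflowClient:
    # Assumption: the private config carries an MLflow tracking URI.
    mlflow.set_tracking_uri(private_cf.mlflow.tracking_uri)
    return MlflowClient(tracking_uri=private_cf.mlflow.tracking_uri)


def get_or_create_mlflow_parent_run(client: MlflowClient, from_run_id: str):
    # Assumption: parent runs are keyed by a tag holding the training run id.
    experiment = client.get_experiment_by_name("weathergen")  # placeholder name
    hits = client.search_runs(
        [experiment.experiment_id],
        filter_string=f"tags.from_run_id = '{from_run_id}'",
        max_results=1,
    )
    if hits:
        return hits[0]
    return client.create_run(experiment.experiment_id, tags={"from_run_id": from_run_id})


def log_scores(scores, client: MlflowClient, run_id: str, channels: list[str]) -> None:
    # Assumption: `scores` is the reordered [metric][region][stream] mapping of
    # DataArrays; log one value per (metric, region, stream, channel) combination.
    for metric, regions in scores.items():
        for region, streams in regions.items():
            for stream, da in streams.items():
                for channel in channels:
                    if channel not in da["channel"]:
                        continue
                    value = float(da.sel(channel=channel).mean())
                    client.log_metric(run_id, f"{metric}.{region}.{stream}.{channel}", value)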

packages/metrics/pyproject.toml

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+[project]
+name = "weathergen-metrics"
+version = "0.1.0"
+description = "The WeatherGenerator Machine Learning Earth System Model"
+readme = "../../README.md"
+requires-python = ">=3.12,<3.13"
+dependencies = [
+    "mlflow-skinny",
+    "weathergen-common",
+]
+
+[dependency-groups]
+dev = [
+    "pytest~=8.3.5",
+    "pytest-mock>=3.14.1",
+    "ruff==0.9.7",
+    "pyrefly==0.36.0",
+]
+
+
+[tool.pyrefly]
+project-includes = ["src/"]
+project-excludes = [
+]
+
+[tool.pyrefly.errors]
+bad-argument-type = false
+unsupported-operation = false
+missing-attribute = false
+no-matching-overload = false
+bad-context-manager = false
+
+# To do:
+bad-assignment = false
+bad-return = false
+index-error = false
+not-iterable = false
+not-callable = false
+
+
+
+
+
+# The linting configuration
+[tool.ruff]
+
+# Wide rows
+line-length = 100
+
+[tool.ruff.lint]
+# All disabled until the code is formatted.
+select = [
+    # pycodestyle
+    "E",
+    # Pyflakes
+    "F",
+    # pyupgrade
+    "UP",
+    # flake8-bugbear
+    "B",
+    # flake8-simplify
+    "SIM",
+    # isort
+    "I",
+    # Banned imports
+    "TID",
+    # Naming conventions
+    "N",
+    # print
+    "T201"
+]
+
+# These rules are sensible and should be enabled at a later stage.
+ignore = [
+    # "B006",
+    "B011",
+    "UP008",
+    "SIM117",
+    "SIM118",
+    "SIM102",
+    "SIM401",
+    # To ignore, not relevant for us
+    "SIM108", # in case additional norm layer supports are added in future
+    "N817", # we use heavy acronyms, e.g., allowing 'import LongModuleName as LMN' (LMN is accepted)
+    "E731", # overly restrictive and less readable code
+    "N812", # prevents us following the convention for importing torch.nn.functional as F
+]
+
+[tool.ruff.lint.flake8-tidy-imports.banned-api]
+"numpy.ndarray".msg = "Do not use 'ndarray' to describe a numpy array type, it is a function. Use numpy.typing.NDArray or numpy.typing.NDArray[np.float32] for example"
+
+[tool.ruff.format]
+# Use Unix `\n` line endings for all files
+line-ending = "lf"
+
+
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/weathergen"]

packages/metrics/src/weathergen/metrics/__init__.py

Whitespace-only changes.
