
Commit 4f4097c

[930][evaluation] implement CSVReader (ecmwf#932)
* first version of quaver reader
* working version
* add CSVReader
* rebase to develop
* add polymorphism
* fix names
* lint
1 parent 1d34631 commit 4f4097c

File tree

3 files changed (+208, -58 lines)


packages/evaluate/src/weathergen/evaluate/io_reader.py

Lines changed: 180 additions & 1 deletion
@@ -7,13 +7,15 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.

+import json
 import logging
 import re
 from dataclasses import dataclass
 from pathlib import Path

 import numpy as np
 import omegaconf as oc
+import pandas as pd
 import xarray as xr
 from tqdm import tqdm

@@ -85,8 +87,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] |
         self.eval_cfg = eval_cfg
         self.run_id = run_id
         self.private_paths = private_paths
-
         self.streams = eval_cfg.streams.keys()
+        self.data = None

         # If results_base_dir and model_base_dir are not provided, default paths are used
         self.model_base_dir = self.eval_cfg.get("model_base_dir", None)
@@ -128,6 +130,10 @@ def get_ensemble(self, stream: str | None = None) -> list[str]:
         """Placeholder implementation ensemble member names getter. Override in subclass."""
         return list()

+    def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray:
+        """Placeholder to load pre-computed scores for a given run, stream, metric"""
+        return None
+
     def check_availability(
         self,
         stream: str,
@@ -309,6 +315,146 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili
         )


+#### Helper function for CSVReader ####
+def _rename_channels(data) -> pd.DataFrame:
+    """
+    The scores downloaded from Quaver follow a different naming convention and need renaming.
+    Rename channel names to include an underscore between letters and digits.
+    E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff'
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        DataFrame whose index contains the original channel names.
+
+    Returns
+    -------
+    pd.DataFrame
+        DataFrame with renamed channel names.
+    """
+    for name in list(data.index):
+        # If it starts with digits (surface vars like 2t, 10ff) -> leave unchanged
+        if re.match(r"^\d", name):
+            continue
+
+        # Otherwise, insert underscore between letters and digits
+        data = data.rename(index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)})
+
+    return data
+
+
+class CsvReader(Reader):
+    """
+    Reader class to read evaluation data from CSV files and convert to xarray DataArray.
+    """
+
+    def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
+        """
+        Initialize the CsvReader.
+
+        Parameters
+        ----------
+        eval_cfg : dict
+            config with plotting and evaluation options for that run id
+        run_id : str
+            run id of the model
+        private_paths : dict, optional
+            private paths for the supported HPCs
+        """
+        super().__init__(eval_cfg, run_id, private_paths)
+        self.csv_path = eval_cfg.get("csv_path")
+        assert self.csv_path is not None, "CSV path must be provided in the config."
+
+        pd_data = pd.read_csv(self.csv_path, index_col=0)
+
+        self.data = _rename_channels(pd_data)
+        self.metrics_base_dir = Path(self.csv_path).parent
+        # for backward compatibility, allow metrics_dir to be specified in the run config
+        self.metrics_dir = Path(
+            self.eval_cfg.get("metrics_dir", self.metrics_base_dir / self.run_id / "evaluation")
+        )
+
+        assert len(eval_cfg.streams.keys()) == 1, "CsvReader only supports one stream."
+        self.stream = list(eval_cfg.streams.keys())[0]
+        self.channels = self.data.index.tolist()
+        self.samples = [0]
+        self.forecast_steps = [int(col.split()[0]) for col in self.data.columns]
+        self.npoints_per_sample = [0]
+        self.epoch = eval_cfg.get("epoch", 0)
+        self.metric = eval_cfg.get("metric")
+        self.region = eval_cfg.get("region")
+
+    def get_samples(self) -> set[int]:
+        """get set of samples for the retrieved scores (initialisation times)"""
+        return set(self.samples)  # Placeholder implementation
+
+    def get_forecast_steps(self) -> set[int]:
+        """get set of forecast steps"""
+        return set(self.forecast_steps)  # Placeholder implementation
+
+    # TODO: get this from config
+    def get_channels(self, stream: str | None = None) -> list[str]:
+        """get set of channels"""
+        assert stream == self.stream, "streams do not match in CSVReader."
+        return list(self.channels)  # Placeholder implementation
+
+    def get_values(self) -> xr.DataArray:
+        """get score values in the right format"""
+        return self.data.values[np.newaxis, :, :, np.newaxis].T
+
+    def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray:
+        """
+        Load the existing scores for a given run, stream and metric.
+
+        Parameters
+        ----------
+        stream :
+            Stream name.
+        region :
+            Region name.
+        metric :
+            Metric name.
+
+        Returns
+        -------
+        xr.DataArray
+            The metric DataArray.
+        """
+        available_data = self.check_availability(stream, mode="evaluation")
+
+        # fill it only for the matching metric
+        if metric == self.metric and region == self.region and stream == self.stream:
+            data = self.get_values()
+        else:
+            data = np.full(
+                (
+                    len(available_data.samples),
+                    len(available_data.fsteps),
+                    len(available_data.channels),
+                    1,
+                ),
+                np.nan,
+            )
+
+        da = xr.DataArray(
+            data.astype(np.float32),
+            dims=("sample", "forecast_step", "channel", "metric"),
+            coords={
+                "sample": available_data.samples,
+                "forecast_step": available_data.fsteps,
+                "channel": available_data.channels,
+                "metric": [metric],
+            },
+            attrs={"npoints_per_sample": self.npoints_per_sample},
+        )
+
+        return da
+
+
 class WeatherGenReader(Reader):
     def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
         """Data reader class for WeatherGenerator model outputs stored in Zarr format."""
@@ -656,6 +802,39 @@ def get_ensemble(self, stream: str | None = None) -> list[str]:
         dummy = zio.get_data(0, stream, zio.forecast_steps[0])
         return list(dummy.prediction.as_xarray().coords["ens"].values)

+    def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
+        """
+        Load the pre-computed scores for a given run, stream, metric and epoch.
+
+        Parameters
+        ----------
+        stream :
+            Stream name.
+        region :
+            Region name.
+        metric :
+            Metric name.
+
+        Returns
+        -------
+        xr.DataArray
+            The metric DataArray, or None if the file does not exist.
+        """
+        score_path = (
+            Path(self.metrics_dir)
+            / f"{self.run_id}_{stream}_{region}_{metric}_epoch{self.epoch:05d}.json"
+        )
+        _logger.debug(f"Looking for: {score_path}")
+
+        if score_path.exists():
+            with open(score_path) as f:
+                data_dict = json.load(f)
+            return xr.DataArray.from_dict(data_dict)
+        else:
+            return None
+
     def get_inference_stream_attr(self, stream_name: str, key: str, default=None):
         """
         Get the value of a key for a specific stream from the model config.
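
To make the new reader's input conventions concrete, here is a minimal runnable sketch (not part of the commit) of how a Quaver-style CSV flows through the pieces above. The file contents are invented; the renaming regex, the forecast-step parsing, and the get_values reshaping mirror the diff:

import io
import re

import numpy as np
import pandas as pd

# Hypothetical Quaver-style scores table: channels on the index,
# forecast steps as column labels with a leading integer.
csv_text = """channel,24 h,48 h
z500,25.1,40.3
t850,0.62,0.95
2t,0.55,0.81
"""
data = pd.read_csv(io.StringIO(csv_text), index_col=0)

# Renaming rule from _rename_channels: names starting with a digit
# (surface variables like 2t, 10ff) stay unchanged; otherwise an
# underscore is inserted between letters and digits.
for name in list(data.index):
    if re.match(r"^\d", name):
        continue
    data = data.rename(index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)})
print(list(data.index))  # ['z_500', 't_850', '2t']

# Forecast steps, parsed as in CsvReader.__init__:
print([int(col.split()[0]) for col in data.columns])  # [24, 48]

# get_values() turns the (channel, forecast_step) table into the
# (sample, forecast_step, channel, metric) layout used by load_scores:
print(data.values[np.newaxis, :, :, np.newaxis].T.shape)  # (1, 2, 3, 1)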

packages/evaluate/src/weathergen/evaluate/run_evaluation.py

Lines changed: 28 additions & 23 deletions
@@ -22,14 +22,13 @@

 from weathergen.common.config import _REPO_ROOT
 from weathergen.common.platform_env import get_platform_env
-from weathergen.evaluate.io_reader import WeatherGenReader
+from weathergen.evaluate.io_reader import CsvReader, WeatherGenReader
 from weathergen.evaluate.plot_utils import collect_channels
 from weathergen.evaluate.utils import (
     calc_scores_per_stream,
     metric_list_to_json,
     plot_data,
     plot_summary,
-    retrieve_metric_from_json,
 )
 from weathergen.metrics.mlflow_utils import (
     MlFlowUpload,
@@ -111,7 +110,13 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None:
     for run_id, run in runs.items():
         _logger.info(f"RUN {run_id}: Getting data...")

-        reader = WeatherGenReader(run, run_id, private_paths)
+        type = run.get("type", "zarr")
+        if type == "zarr":
+            reader = WeatherGenReader(run, run_id, private_paths)
+        elif type == "csv":
+            reader = CsvReader(run, run_id, private_paths)
+        else:
+            raise ValueError(f"Unknown run type {type} for run {run_id}. Supported: zarr, csv.")

         for stream in reader.streams:
             _logger.info(f"RUN {run_id}: Processing stream {stream}...")
@@ -135,29 +140,29 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None:
                 metrics_to_compute = []

                 for metric in metrics:
-                    try:
-                        metric_data = retrieve_metric_from_json(
-                            reader,
-                            stream,
-                            region,
-                            metric,
-                        )
+                    metric_data = reader.load_scores(
+                        stream,
+                        region,
+                        metric,
+                    )

-                        available_data = reader.check_availability(
-                            stream, metric_data, mode="evaluation"
-                        )
+                    if metric_data is None:
+                        metrics_to_compute.append(metric)
+                        continue
+
+                    available_data = reader.check_availability(
+                        stream, metric_data, mode="evaluation"
+                    )

-                        if not available_data.score_availability:
-                            metrics_to_compute.append(metric)
-                        else:
-                            # simply select the chosen eval channels, samples, fsteps here...
-                            scores_dict[metric][region][stream][run_id] = metric_data.sel(
-                                sample=available_data.samples,
-                                channel=available_data.channels,
-                                forecast_step=available_data.fsteps,
-                            )
-                    except (FileNotFoundError, KeyError):
+                    if not available_data.score_availability:
                         metrics_to_compute.append(metric)
+                    else:
+                        # simply select the chosen eval channels, samples, fsteps here...
+                        scores_dict[metric][region][stream][run_id] = metric_data.sel(
+                            sample=available_data.samples,
+                            channel=available_data.channels,
+                            forecast_step=available_data.fsteps,
+                        )

                 if metrics_to_compute:
                     all_metrics, points_per_sample = calc_scores_per_stream(
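
For reference, here is a hypothetical run entry that would exercise the new csv branch. Only the keys actually read in the diff (type, csv_path, metric, region, epoch, streams) are grounded; the surrounding config schema is assumed:

from omegaconf import OmegaConf

runs = OmegaConf.create(
    {
        # No "type" key: defaults to "zarr" and is handled by WeatherGenReader.
        "my_model_run": {"streams": {"ERA5": {}}},
        # "type: csv" routes this run to CsvReader.
        "ifs_baseline": {
            "type": "csv",
            "csv_path": "/path/to/quaver_scores.csv",
            "metric": "rmse",
            "region": "global",
            "epoch": 0,
            "streams": {"ERA5": {}},  # CsvReader asserts exactly one stream
        },
    }
)

for run_id, run in runs.items():
    print(run_id, "->", run.get("type", "zarr"))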

packages/evaluate/src/weathergen/evaluate/utils.py

Lines changed: 0 additions & 34 deletions
@@ -375,40 +375,6 @@ def metric_list_to_json(
     )


-def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: str):
-    """
-    Retrieve the score for a given run, stream, metric, epoch, and rank from a JSON file.
-
-    Parameters
-    ----------
-    reader :
-        Reader object containing all info for a specific run_id
-    stream :
-        Stream name.
-    region :
-        Region name.
-    metric :
-        Metric name.
-
-    Returns
-    -------
-    xr.DataArray
-        The metric DataArray.
-    """
-    score_path = (
-        Path(reader.metrics_dir)
-        / f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json"
-    )
-    _logger.debug(f"Looking for: {score_path}")
-
-    if score_path.exists():
-        with open(score_path) as f:
-            data_dict = json.load(f)
-        return xr.DataArray.from_dict(data_dict)
-    else:
-        raise FileNotFoundError(f"File {score_path} not found in the archive.")
-
-
 def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path):
     """
     Plot summary of the evaluation results.
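
The JSON files whose lookup moved from this helper into WeatherGenReader.load_scores are, judging by the xr.DataArray.from_dict call, xarray's dictionary serialization. A small round-trip sketch under that assumption, with an invented path following the {run_id}_{stream}_{region}_{metric}_epoch{epoch:05d}.json pattern:

import json

import numpy as np
import xarray as xr

# Toy scores array with the dims used throughout the diff.
da = xr.DataArray(
    np.zeros((1, 2, 3, 1), dtype=np.float32),
    dims=("sample", "forecast_step", "channel", "metric"),
    coords={
        "sample": [0],
        "forecast_step": [24, 48],
        "channel": ["z_500", "t_850", "2t"],
        "metric": ["rmse"],
    },
)

# Write the way metric_list_to_json presumably does...
path = "/tmp/run123_ERA5_global_rmse_epoch00000.json"
with open(path, "w") as f:
    json.dump(da.to_dict(), f)

# ...and read it back the way load_scores does.
with open(path) as f:
    restored = xr.DataArray.from_dict(json.load(f))

assert restored.dims == ("sample", "forecast_step", "channel", "metric")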
