|
7 | 7 | # granted to it by virtue of its status as an intergovernmental organisation |
8 | 8 | # nor does it submit to any jurisdiction. |
9 | 9 |
|
| 10 | +import json |
10 | 11 | import logging |
11 | 12 | import re |
12 | 13 | from dataclasses import dataclass |
13 | 14 | from pathlib import Path |
14 | 15 |
|
15 | 16 | import numpy as np |
16 | 17 | import omegaconf as oc |
| 18 | +import pandas as pd |
17 | 19 | import xarray as xr |
18 | 20 | from tqdm import tqdm |
19 | 21 |
|
@@ -85,8 +87,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | |
85 | 87 | self.eval_cfg = eval_cfg |
86 | 88 | self.run_id = run_id |
87 | 89 | self.private_paths = private_paths |
88 | | - |
89 | 90 | self.streams = eval_cfg.streams.keys() |
| 91 | + self.data = None |
90 | 92 |
|
91 | 93 | # If results_base_dir and model_base_dir are not provided, default paths are used |
92 | 94 | self.model_base_dir = self.eval_cfg.get("model_base_dir", None) |
@@ -128,6 +130,10 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: |
128 | 130 | """Placeholder implementation ensemble member names getter. Override in subclass.""" |
129 | 131 | return list() |
130 | 132 |
|
    def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
        """Placeholder to load pre-computed scores for a given run, stream, metric.

        Base implementation always returns None; subclasses (e.g. CsvReader,
        WeatherGenReader) override this to return the actual scores.
        """
        return None
| 136 | + |
131 | 137 | def check_availability( |
132 | 138 | self, |
133 | 139 | stream: str, |
@@ -309,6 +315,146 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili |
309 | 315 | ) |
310 | 316 |
|
311 | 317 |
|
| 318 | +##### Helper function for CSVReader #### |
| 319 | +def _rename_channels(data) -> pd.DataFrame: |
| 320 | + """ |
| 321 | + The scores downloaded from Quaver have a different convention. Need renaming. |
| 322 | + Rename channel names to include underscore between letters and digits. |
| 323 | + E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff' |
| 324 | +
|
| 325 | + Parameters |
| 326 | + ---------- |
| 327 | + name : str |
| 328 | + Original channel name. |
| 329 | +
|
| 330 | + Returns |
| 331 | + ------- |
| 332 | + pd.DataFrame |
| 333 | + Dataset with renamed channel names. |
| 334 | + """ |
| 335 | + for name in list(data.index): |
| 336 | + # If it starts with digits (surface vars like 2t, 10ff) → leave unchanged |
| 337 | + if re.match(r"^\d", name): |
| 338 | + continue |
| 339 | + |
| 340 | + # Otherwise, insert underscore between letters and digits |
| 341 | + data = data.rename(index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)}) |
| 342 | + |
| 343 | + return data |
| 344 | + |
| 345 | + |
class CsvReader(Reader):
    """
    Reader class to read evaluation data from CSV files and convert to xarray DataArray.

    Supports exactly one stream; the CSV is expected to have channel names as
    its index and forecast-step columns whose labels start with an integer
    (e.g. "6 h").
    """

    def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
        """
        Initialize the CsvReader.

        Parameters
        ----------
        eval_cfg : dict
            Config with plotting and evaluation options for that run id.
            Must contain ``csv_path`` and exactly one stream.
        run_id : str
            Run id of the model.
        private_paths : dict | None
            Mapping of private paths for the supported HPC.
        """

        super().__init__(eval_cfg, run_id, private_paths)
        self.csv_path = eval_cfg.get("csv_path")
        assert self.csv_path is not None, "CSV path must be provided in the config."

        pd_data = pd.read_csv(self.csv_path, index_col=0)

        # Quaver CSVs use e.g. 'z500'; internally 'z_500' is expected.
        self.data = _rename_channels(pd_data)
        self.metrics_base_dir = Path(self.csv_path).parent
        # for backward compatibility allow metric_dir to be specified in the run config
        self.metrics_dir = Path(
            self.eval_cfg.get("metrics_dir", self.metrics_base_dir / self.run_id / "evaluation")
        )

        assert len(eval_cfg.streams.keys()) == 1, "CsvReader only supports one stream."
        self.stream = next(iter(eval_cfg.streams.keys()))
        self.channels = self.data.index.tolist()
        # A CSV holds scores for a single (dummy) initialisation time.
        self.samples = [0]
        # Column labels are expected to start with the integer step, e.g. "6 h".
        self.forecast_steps = [int(col.split()[0]) for col in self.data.columns]
        self.npoints_per_sample = [0]
        self.epoch = eval_cfg.get("epoch", 0)
        self.metric = eval_cfg.get("metric")
        self.region = eval_cfg.get("region")

    def get_samples(self) -> set[int]:
        """Get set of samples for the retrieved scores (initialisation times)."""
        return set(self.samples)

    def get_forecast_steps(self) -> set[int]:
        """Get set of forecast steps."""
        return set(self.forecast_steps)

    # TODO: get this from config
    def get_channels(self, stream: str | None = None) -> list[str]:
        """Get list of channels for the (single) supported stream."""
        assert stream == self.stream, "streams do not match in CSVReader."
        return list(self.channels)

    def get_values(self) -> np.ndarray:
        """
        Get score values shaped (sample, forecast_step, channel, metric).

        The CSV is (channel, forecast_step); singleton sample/metric axes are
        added and the result transposed to match the layout used in load_scores.
        """
        return self.data.values[np.newaxis, :, :, np.newaxis].T

    def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray:
        """
        Load the existing scores for a given run, stream and metric.

        Parameters
        ----------
        stream :
            Stream name.
        region :
            Region name.
        metric :
            Metric name.

        Returns
        -------
        xr.DataArray
            The metric DataArray; filled with NaN when stream, region or
            metric do not match the ones this reader was configured with.
        """

        available_data = self.check_availability(stream, mode="evaluation")

        # fill it only for matching metric
        if metric == self.metric and region == self.region and stream == self.stream:
            data = self.get_values()
        else:
            data = np.full(
                (
                    len(available_data.samples),
                    len(available_data.fsteps),
                    len(available_data.channels),
                    1,
                ),
                np.nan,
            )

        da = xr.DataArray(
            data.astype(np.float32),
            dims=("sample", "forecast_step", "channel", "metric"),
            coords={
                "sample": available_data.samples,
                "forecast_step": available_data.fsteps,
                "channel": available_data.channels,
                "metric": [metric],
            },
            attrs={"npoints_per_sample": self.npoints_per_sample},
        )

        return da
| 457 | + |
312 | 458 | class WeatherGenReader(Reader): |
313 | 459 | def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): |
314 | 460 | """Data reader class for WeatherGenerator model outputs stored in Zarr format.""" |
@@ -656,6 +802,39 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: |
656 | 802 | dummy = zio.get_data(0, stream, zio.forecast_steps[0]) |
657 | 803 | return list(dummy.prediction.as_xarray().coords["ens"].values) |
658 | 804 |
|
| 805 | + def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None: |
| 806 | + """ |
| 807 | + Load the pre-computed scores for a given run, stream and metric and epoch. |
| 808 | +
|
| 809 | + Parameters |
| 810 | + ---------- |
| 811 | + reader : |
| 812 | + Reader object containing all info for a specific run_id |
| 813 | + stream : |
| 814 | + Stream name. |
| 815 | + region : |
| 816 | + Region name. |
| 817 | + metric : |
| 818 | + Metric name. |
| 819 | +
|
| 820 | + Returns |
| 821 | + ------- |
| 822 | + xr.DataArray |
| 823 | + The metric DataArray or None if the file does not exist. |
| 824 | + """ |
| 825 | + score_path = ( |
| 826 | + Path(self.metrics_dir) |
| 827 | + / f"{self.run_id}_{stream}_{region}_{metric}_epoch{self.epoch:05d}.json" |
| 828 | + ) |
| 829 | + _logger.debug(f"Looking for: {score_path}") |
| 830 | + |
| 831 | + if score_path.exists(): |
| 832 | + with open(score_path) as f: |
| 833 | + data_dict = json.load(f) |
| 834 | + return xr.DataArray.from_dict(data_dict) |
| 835 | + else: |
| 836 | + return None |
| 837 | + |
659 | 838 | def get_inference_stream_attr(self, stream_name: str, key: str, default=None): |
660 | 839 | """ |
661 | 840 | Get the value of a key for a specific stream from the a model config. |
|
0 commit comments