Skip to content

Commit 089de3b

Browse files
authored
Implement regional evaluation (ecmwf#652)
* Add RegionBoundingBox data class to score-utils to handle evaluation for different regions. * Implement region-specific evaluation in plot_inference.py. * Adapted utils. * Introduction of clean RegionLibrary in score_utils.py. * Ruffed code. * Updates following reviewer comments. * Ruffed code.
1 parent e2cc583 commit 089de3b

File tree

3 files changed

+270
-110
lines changed

3 files changed

+270
-110
lines changed

packages/evaluate/src/weathergen/evaluate/plot_inference.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ def run_main(cfg: DictConfig) -> None:
3838
out_scores_dir.mkdir(parents=True, exist_ok=True)
3939

4040
metrics = cfg.evaluation.metrics
41+
regions = cfg.evaluation.get("regions", ["global"])
4142

42-
# to get a structure like: scores_dict[metric][stream][run_id] = plot
43-
scores_dict = defaultdict(lambda: defaultdict(dict))
43+
# to get a structure like: scores_dict[metric][region][stream][run_id] = plot
44+
scores_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
4445

4546
for run_id, run in runs.items():
4647
_logger.info(f"RUN {run_id}: Getting data...")
@@ -59,34 +60,37 @@ def run_main(cfg: DictConfig) -> None:
5960
if stream_dict.get("evaluation"):
6061
_logger.info(f"Retrieve or compute scores for {run_id} - {stream}...")
6162

62-
metrics_to_compute = []
63-
for metric in metrics:
64-
try:
65-
metric_data = retrieve_metric_from_json(
63+
for region in regions:
64+
metrics_to_compute = []
65+
66+
for metric in metrics:
67+
try:
68+
metric_data = retrieve_metric_from_json(
69+
out_scores_dir,
70+
run_id,
71+
stream,
72+
region,
73+
metric,
74+
run.epoch,
75+
)
76+
scores_dict[metric][region][stream][run_id] = metric_data
77+
except (FileNotFoundError, KeyError, ValueError):
78+
metrics_to_compute.append(metric)
79+
80+
if metrics_to_compute:
81+
all_metrics, points_per_sample = calc_scores_per_stream(
82+
cfg, run_id, stream, region, metrics_to_compute
83+
)
84+
85+
metric_list_to_json(
86+
[all_metrics],
87+
[points_per_sample],
88+
[stream],
89+
region,
6690
out_scores_dir,
6791
run_id,
68-
stream,
69-
metric,
7092
run.epoch,
71-
run.rank,
7293
)
73-
scores_dict[metric][stream][run_id] = metric_data
74-
except (FileNotFoundError, KeyError, ValueError):
75-
metrics_to_compute.append(metric)
76-
if metrics_to_compute:
77-
all_metrics, points_per_sample = calc_scores_per_stream(
78-
cfg, run_id, stream, metrics_to_compute
79-
)
80-
81-
metric_list_to_json(
82-
[all_metrics],
83-
[points_per_sample],
84-
[stream],
85-
out_scores_dir,
86-
run_id,
87-
run.epoch,
88-
run.rank,
89-
)
9094

9195
for metric in metrics_to_compute:
9296
scores_dict[metric][stream][run_id] = all_metrics.sel(

packages/evaluate/src/weathergen/evaluate/score_utils.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,16 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10-
from typing import Any
10+
import logging
11+
from dataclasses import dataclass
12+
from typing import Any, ClassVar
1113

14+
import xarray as xr
1215
from omegaconf.listconfig import ListConfig
1316

17+
_logger = logging.getLogger(__name__)
18+
_logger.setLevel(logging.INFO)
19+
1420

1521
def to_list(obj: Any) -> list:
1622
"""
@@ -30,3 +36,100 @@ def to_list(obj: Any) -> list:
3036
elif not isinstance(obj, list):
3137
obj = [obj]
3238
return obj
39+
40+
41+
class RegionLibrary:
42+
"""
43+
Predefined bounding boxes for known regions.
44+
"""
45+
46+
REGIONS: ClassVar[dict[str, tuple[float, float, float, float]]] = {
47+
"global": (-90.0, 90.0, -180.0, 180.0),
48+
"nhem": (0.0, 90.0, -180.0, 180.0),
49+
"shem": (-90.0, 0.0, -180.0, 180.0),
50+
"tropics": (-30.0, 30.0, -180.0, 180.0),
51+
}
52+
53+
54+
@dataclass(frozen=True)
55+
class RegionBoundingBox:
56+
lat_min: float
57+
lat_max: float
58+
lon_min: float
59+
lon_max: float
60+
61+
def __post_init__(self):
62+
"""Validate the bounding box coordinates."""
63+
self.validate()
64+
65+
def validate(self):
66+
"""Validate the bounding box coordinates."""
67+
if not (-90 <= self.lat_min <= 90 and -90 <= self.lat_max <= 90):
68+
raise ValueError(
69+
f"Latitude bounds must be between -90 and 90. Got: {self.lat_min}, {self.lat_max}"
70+
)
71+
if not (-180 <= self.lon_min <= 180 and -180 <= self.lon_max <= 180):
72+
raise ValueError(
73+
f"Longitude bounds must be between -180 and 180. Got: {self.lon_min}, {self.lon_max}"
74+
)
75+
if self.lat_min >= self.lat_max:
76+
raise ValueError(
77+
f"Latitude minimum must be less than maximum. Got: {self.lat_min}, {self.lat_max}"
78+
)
79+
if self.lon_min >= self.lon_max:
80+
raise ValueError(
81+
f"Longitude minimum must be less than maximum. Got: {self.lon_min}, {self.lon_max}"
82+
)
83+
84+
def contains(self, lat: float, lon: float) -> bool:
85+
"""Check if a lat/lon point is within the bounding box."""
86+
return (self.lat_min <= lat <= self.lat_max) and (
87+
self.lon_min <= lon <= self.lon_max
88+
)
89+
90+
def apply_mask(
91+
self,
92+
data: xr.Dataset | xr.DataArray,
93+
lat_name: str = "lat",
94+
lon_name: str = "lon",
95+
data_dim: str = "ipoint",
96+
) -> xr.Dataset | xr.DataArray:
97+
"""Filter Dataset or DataArray by spatial bounding box on 'ipoint' dimension.
98+
Parameters
99+
----------
100+
data :
101+
The data to filter.
102+
lat_name:
103+
Name of the latitude coordinate in the data.
104+
lon_name:
105+
Name of the longitude coordinate in the data.
106+
data_dim:
107+
Name of the dimension that contains the lat/lon coordinates.
108+
109+
Returns
110+
-------
111+
Filtered data with only points within the bounding box.
112+
"""
113+
# lat/lon coordinates should be 1D and aligned with ipoint
114+
lat = data[lat_name]
115+
lon = data[lon_name]
116+
117+
mask = (
118+
(lat >= self.lat_min)
119+
& (lat <= self.lat_max)
120+
& (lon >= self.lon_min)
121+
& (lon <= self.lon_max)
122+
)
123+
124+
return data.sel({data_dim: mask})
125+
126+
@classmethod
127+
def from_region_name(cls, region: str) -> "RegionBoundingBox":
128+
region = region.lower()
129+
try:
130+
return cls(*RegionLibrary.REGIONS[region])
131+
except KeyError as err:
132+
raise ValueError(
133+
f"Region '{region}' is not supported. "
134+
f"Available regions: {', '.join(RegionLibrary.REGIONS.keys())}"
135+
) from err

0 commit comments

Comments
 (0)