Skip to content

Commit 75a5e3b

Browse files
authored
feat: Mars with Nested models (ecmwf#236)
## Description The mars input and by association other inputs cannot be used with nested models. See ecmwf#235 for context. In addition to the changes in this PR, some extra manual patches of the checkpoint are required. This enables an input of ```yaml cutout: lam_0: mars: grid: '0.01/0.01' area: '2/-2/-2/2' mask: 'source0/trimedge_mask' global: 'mars' ```
1 parent d6863a0 commit 75a5e3b

File tree

4 files changed

+52
-14
lines changed

4 files changed

+52
-14
lines changed

src/anemoi/inference/checkpoint.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -612,13 +612,18 @@ def mars_requests(
612612
if use_grib_paramid and "param" in r:
613613
r["param"] = [shortname_to_paramid(_) for _ in r["param"]]
614614

615-
# Simplyfie the request
616-
617-
for k, v in r.items():
615+
# Simplify the request
618616

617+
for k in list(r.keys()):
618+
v = r[k]
619619
if len(v) == 1:
620-
r[k] = v[0]
620+
v = v[0]
621621

622+
# Remove empty values for when tree is not fully defined
623+
if v == "-":
624+
r.pop(k)
625+
continue
626+
r[k] = v
622627
result.append(r)
623628

624629
return result

src/anemoi/inference/inputs/cutout.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,36 @@
2424
LOG = logging.getLogger(__name__)
2525

2626

27+
def _mask_and_nest_state(
28+
state: State,
29+
_state: State,
30+
mask: np.ndarray,
31+
) -> State:
32+
"""Mask and nest the state with the given mask.
33+
34+
Parameters
35+
----------
36+
state : State
37+
The state to be masked and nested.
38+
_state : State
39+
The state to be masked and nested with.
40+
mask : np.ndarray
41+
The mask to be applied.
42+
43+
Returns
44+
-------
45+
State
46+
The masked and nested state.
47+
"""
48+
for field, values in state["fields"].items():
49+
state["fields"][field] = np.concatenate([values, _state["fields"][field][..., mask]], axis=-1)
50+
51+
state["latitudes"] = np.concatenate([state["latitudes"], _state["latitudes"][..., mask]], axis=-1)
52+
state["longitudes"] = np.concatenate([state["longitudes"], _state["longitudes"][..., mask]], axis=-1)
53+
54+
return state
55+
56+
2757
@input_registry.register("cutout")
2858
class Cutout(Input):
2959
"""Combines one or more LAMs into a global source using cutouts."""
@@ -43,7 +73,10 @@ def __init__(self, context, **sources: dict[str, dict]):
4373
self.sources: dict[str, Input] = {}
4474
self.masks: dict[str, np.ndarray] = {}
4575
for src, cfg in sources.items():
46-
mask = cfg.pop("mask", f"{src}/cutout_mask")
76+
if isinstance(cfg, str):
77+
mask = f"{src}/cutout_mask"
78+
else:
79+
mask = cfg.pop("mask", f"{src}/cutout_mask")
4780
self.sources[src] = create_input(context, cfg)
4881
self.masks[src] = self.sources[src].checkpoint.load_supporting_array(mask)
4982

@@ -72,10 +105,7 @@ def create_input_state(self, *, date: Optional[Date]) -> State:
72105
mask = self.masks[source]
73106
_state = self.sources[source].create_input_state(date=date)
74107

75-
state["latitudes"] = np.concatenate([state["latitudes"], _state["latitudes"][..., mask]], axis=-1)
76-
state["longitudes"] = np.concatenate([state["longitudes"], _state["longitudes"][..., mask]], axis=-1)
77-
for field, values in state["fields"].items():
78-
state["fields"][field] = np.concatenate([values, _state["fields"][field][..., mask]], axis=-1)
108+
state = _mask_and_nest_state(state, _state, mask)
79109

80110
return state
81111

src/anemoi/inference/inputs/mars.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ def postproc(
124124
if grid_is_valid(grid):
125125
pproc["grid"] = grid
126126

127+
if isinstance(area, str):
128+
area = [float(x) for x in area.split("/")]
129+
127130
if area_is_valid(area):
128131
pproc["area"] = rounded_area(area)
129132

@@ -219,7 +222,7 @@ def __init__(
219222
patches : Optional[List[Tuple[Dict[str, Any], Dict[str, Any]]]]
220223
Optional list of patches for the input.
221224
**kwargs : Any
222-
Additional keyword arguments.
225+
Additional keyword to pass to the request to MARS.
223226
"""
224227
super().__init__(context, namer=namer)
225228
self.kwargs = kwargs
@@ -282,12 +285,12 @@ def retrieve(self, variables: List[str], dates: List[Date]) -> Any:
282285

283286
kwargs = self.kwargs.copy()
284287
kwargs.setdefault("expver", "0001")
288+
kwargs.setdefault("grid", self.checkpoint.grid)
289+
kwargs.setdefault("area", self.checkpoint.area)
285290

286291
return retrieve(
287292
requests,
288-
self.checkpoint.grid,
289-
self.checkpoint.area,
290-
self.patch,
293+
patch=self.patch,
291294
**kwargs,
292295
)
293296

src/anemoi/inference/metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ def mars_requests(self, *, variables: List[str]) -> Iterator[DataRequest]:
596596

597597
mars = self.variables_metadata[variable]["mars"].copy()
598598

599-
for k in ("date", "hdate", "time"):
599+
for k in ("date", "hdate", "time", "valid_datetime", "variable"):
600600
mars.pop(k, None)
601601

602602
yield mars

0 commit comments

Comments
 (0)