Skip to content

Commit eb7afd1

Browse files
authored
fix: error combining 3 or plus inputs in cutout (ecmwf#256)
## Description ecmwf#249 introduced a bug when there is more than 2 sources. Indeed, when `combined_mask` is None, it adds an extra dimension to the state. I also added a test. ## What problem does this change solve? <!-- Describe if it's a bugfix, new feature, doc update, or breaking change --> ## What issue or task does this change relate to? <!-- link to Issue Number --> ## Additional notes ## <!-- Include any additional information, caveats, or considerations that the reviewer should be aware of. --> ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/***
1 parent 74c8e20 commit eb7afd1

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

src/anemoi/inference/inputs/cutout.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import Iterable
1313
from typing import List
1414
from typing import Optional
15+
from typing import Union
1516

1617
import numpy as np
1718

@@ -28,7 +29,7 @@
2829
def _mask_and_combine_states(
2930
combined_state: State,
3031
new_state: State,
31-
combined_mask: Optional[np.ndarray],
32+
combined_mask: Union[np.ndarray, slice],
3233
mask: np.ndarray,
3334
fields: Iterable[str],
3435
) -> State:
@@ -119,7 +120,7 @@ def create_input_state(self, *, date: Optional[Date]) -> State:
119120
combined_state["fields"] = _mask_and_combine_states(
120121
combined_state["fields"], new_state["fields"], combined_mask, mask, combined_state["fields"]
121122
)
122-
combined_mask = None
123+
combined_mask = slice(0, None)
123124

124125
return combined_state
125126

@@ -154,7 +155,7 @@ def load_forcings_state(self, *, variables: List[str], dates: List[Date], curren
154155
combined_fields = _mask_and_combine_states(
155156
combined_fields, new_fields, combined_mask, mask, combined_fields
156157
)
157-
combined_mask = None
158+
combined_mask = slice(0, None)
158159

159160
current_state["fields"] |= combined_fields
160161
return current_state

tests/unit/test_input_cutout.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
3+
from anemoi.inference.inputs.cutout import _mask_and_combine_states
4+
5+
6+
def test_mask_and_combine_states():
7+
states = [{"a": np.arange(5).astype(float)}, {"a": np.arange(5, 10).astype(float)}, {"a": np.arange(10, 15)}]
8+
masks = [np.zeros(5).astype(bool) for _ in states]
9+
masks[0][[1, 2]] = True
10+
masks[1][[2, 3]] = True
11+
masks[2][[2, 4]] = True
12+
combined_state = {k: states[0][k][:] for k in states[0]}
13+
combined_mask = masks[0]
14+
for k in range(1, len(states)):
15+
mask = masks[k]
16+
new_state = states[k]
17+
combined_state = _mask_and_combine_states(combined_state, new_state, combined_mask, mask, ["a"])
18+
combined_mask = slice(0, None)
19+
20+
assert combined_state["a"].shape[0] == 6
21+
assert (
22+
combined_state["a"]
23+
== np.array(
24+
[
25+
states[0]["a"][1],
26+
states[0]["a"][2],
27+
states[1]["a"][2],
28+
states[1]["a"][3],
29+
states[2]["a"][2],
30+
states[2]["a"][4],
31+
]
32+
)
33+
).all()

0 commit comments

Comments
 (0)