Skip to content

Commit 2d4a54d

Browse files
authored
fix(configs,schemas): hierarchical schemas (ecmwf#221)
1 parent f1d5e1f commit 2d4a54d

File tree

4 files changed

+28
-12
lines changed

4 files changed

+28
-12
lines changed

training/src/anemoi/training/config/graph/hierarchical_2level.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ edges:
4040
edge_builders:
4141
- _target_: anemoi.graphs.edges.KNNEdges
4242
num_nearest_neighbours: 3
43+
source_mask_attr_name: null
44+
target_mask_attr_name: null
4345
attributes: ${graph.attributes.edges}
4446

4547
# Hierarchical connections: downscale

training/src/anemoi/training/config/graph/hierarchical_3level.yaml

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,55 +32,66 @@ edges:
3232
# Encoder configuration
3333
- source_name: "data"
3434
target_name: "hidden_1"
35-
edge_builder:
36-
_target_: anemoi.graphs.edges.CutOffEdges
35+
edge_builders:
36+
- _target_: anemoi.graphs.edges.CutOffEdges
3737
cutoff_factor: 0.6
38+
source_mask_attr_name: null
39+
target_mask_attr_name: null
3840
attributes: ${graph.attributes.edges}
3941

4042
# Decoder configuration
4143
- source_name: "hidden_1"
4244
target_name: "data"
43-
edge_builder:
44-
_target_: anemoi.graphs.edges.KNNEdges
45+
edge_builders:
46+
- _target_: anemoi.graphs.edges.KNNEdges
4547
num_nearest_neighbours: 3
48+
source_mask_attr_name: null
49+
target_mask_attr_name: null
4650
attributes: ${graph.attributes.edges}
4751

4852
# Hierarchical connections: downscale
4953
- source_name: "hidden_1"
5054
target_name: "hidden_2"
51-
edge_builder: ${graph.edge_builders.downscale}
55+
edge_builders:
56+
- ${graph.edge_builders.downscale}
5257
attributes: ${graph.attributes.edges}
5358

5459
- source_name: "hidden_2"
5560
target_name: "hidden_3"
56-
edge_builder: ${graph.edge_builders.downscale}
61+
edge_builders:
62+
- ${graph.edge_builders.downscale}
5763
attributes: ${graph.attributes.edges}
5864

5965
# Hierarchical connections: upscale
6066
- source_name: "hidden_3"
6167
target_name: "hidden_2"
62-
edge_builder: ${graph.edge_builders.upscale}
68+
edge_builders:
69+
- ${graph.edge_builders.upscale}
6370
attributes: ${graph.attributes.edges}
6471

6572
- source_name: "hidden_2"
6673
target_name: "hidden_1"
67-
edge_builder: ${graph.edge_builders.upscale}
74+
edge_builders:
75+
- ${graph.edge_builders.upscale}
6876
attributes: ${graph.attributes.edges}
6977

7078
# Hierarchical connections: same level
7179
- source_name: "hidden_1"
7280
target_name: "hidden_1"
73-
edge_builder: ${graph.edge_builders.process}
81+
edge_builders:
82+
- ${graph.edge_builders.process}
7483
attributes: ${graph.attributes.edges}
7584

7685
- source_name: "hidden_2"
7786
target_name: "hidden_2"
78-
edge_builder: ${graph.edge_builders.process}
87+
edge_builders:
88+
- ${graph.edge_builders.process}
7989
attributes: ${graph.attributes.edges}
8090

8191
- source_name: "hidden_3"
8292
target_name: "hidden_3"
83-
edge_builder: ${graph.edge_builders.process}
93+
edge_builders:
94+
- ${graph.edge_builders.process}
8495
attributes: ${graph.attributes.edges}
8596

8697

training/src/anemoi/training/schemas/graphs/base_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class BaseGraphSchema(PydanticBaseModel):
5555
post_processors: list[ProcessorSchemas] = Field(default_factory=list)
5656
data: str = Field(example="data")
5757
"Key name for the data nodes. Default to 'data'."
58-
hidden: str = Field(example="hidden")
58+
hidden: Union[str, list[str]] = Field(example="hidden")
5959
"Key name for the hidden nodes. Default to 'hidden'."
6060
# TODO(Helen): Needs to be adjusted for more complex graph setups
6161

training/src/anemoi/training/schemas/models/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,11 @@
3535

3636
class DefinedModels(str, Enum):
3737
ANEMOI_MODEL_ENC_PROC_DEC = "anemoi.models.models.encoder_processor_decoder.AnemoiModelEncProcDec"
38+
ANEMOI_MODEL_ENC_PROC_DEC_SHORT = "anemoi.models.models.AnemoiModelEncProcDec"
3839
ANEMOI_MODEL_ENC_HIERPROC_DEC = "anemoi.models.models.hierarchical.AnemoiModelEncProcDecHierarchical"
40+
ANEMOI_MODEL_ENC_HIERPROC_DEC_SHORT = "anemoi.models.models.AnemoiModelEncProcDecHierarchical"
3941
ANEMOI_MODEL_INTERPENC_PROC_DEC = "anemoi.models.models.interpolator.AnemoiModelEncProcDecInterpolator"
42+
ANEMOI_MODEL_INTERPENC_PROC_DEC_SHORT = "anemoi.models.models.AnemoiModelEncProcDecInterpolator"
4043

4144

4245
class Model(BaseModel):

0 commit comments

Comments
 (0)