Skip to content

Commit d6d7f15

Browse files
authored
feat(graphs): New Edge Attribute: AttributeFromNode (ecmwf#62)
* Implemented new attribute * changelog update * changelog update * Refactored following review * Update src/anemoi/graphs/edges/attributes.py Co-authored-by: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> * Update src/anemoi/graphs/edges/attributes.py Co-authored-by: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> * Update src/anemoi/graphs/edges/attributes.py Co-authored-by: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> * Update src/anemoi/graphs/edges/attributes.py Co-authored-by: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> * Refactor after review * Docstring done * Docstring done * Docstring done * Fixed ABC issue * addressed changes in docs and exception error * Changed changelog * TMP * Addressed comments * gpc * Removed update to changelog * Added test for fail copy * Gpc passed * Fixed test * Fixed test * Test working now
1 parent 4690ed5 commit d6d7f15

File tree

3 files changed

+148
-5
lines changed

3 files changed

+148
-5
lines changed

graphs/docs/graphs/edge_attributes.rst

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
Edges - Attributes
55
####################
66

7-
There are two main edge attributes implemented in the
8-
:ref:`anemoi-graphs <anemoi-graphs:index-page>` package:
7+
There are few edge attributes implemented in the `anemoi-graphs`
8+
package:
99

1010
*************
1111
Edge length
@@ -44,3 +44,49 @@ latitude and longitude coordinates of the source and target nodes.
4444
attributes:
4545
edge_length:
4646
_target_: anemoi.graphs.edges.attributes.EdgeDirection
47+
48+
*********************
49+
Attribute from Node
50+
*********************
51+
52+
Attributes can also be copied from nodes to edges. This is done using
53+
the `AttributeFromNode` base class, with specialized versions for source
54+
and target nodes.
55+
56+
From Source
57+
===========
58+
59+
This attribute copies a specific property of the source node to the
60+
edge. Example usage for copying the cutout mask from nodes to edges in
61+
the encoder:
62+
63+
.. code:: yaml
64+
65+
edges:
66+
# Encoder
67+
- source_name: data
68+
target_name: hidden
69+
edge_builders: ...
70+
attributes:
71+
comes_from_cutout: # Assigned name to the edge attribute, can be different than node_attr_name
72+
_target_: anemoi.graphs.edges.attributes.AttributeFromSourceNode
73+
node_attr_name: cutout
74+
75+
From Target
76+
===========
77+
78+
This attribute copies a specific property of the target node to the
79+
edge. Example usage for copying the coutout mask from nodes to edges in
80+
the decoder:
81+
82+
.. code:: yaml
83+
84+
edges:
85+
# Decoder
86+
- source_name: hidden
87+
target_name: data
88+
edge_builders: ...
89+
attributes:
90+
comes_from_cutout: # Assigned name to the edge attribute, can be different than node_attr_name
91+
_target_: anemoi.graphs.edges.attributes.AttributeFromTargetNode
92+
node_attr_name: cutout

graphs/src/anemoi/graphs/edges/attributes.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424
class BaseEdgeAttribute(ABC, NormaliserMixin):
2525
"""Base class for edge attributes."""
2626

27-
def __init__(self, norm: str | None = None) -> None:
27+
def __init__(self, norm: str | None = None, dtype: str = "float32") -> None:
2828
self.norm = norm
29+
self.dtype = dtype
2930

3031
@abstractmethod
3132
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ...
@@ -35,9 +36,9 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:
3536
if values.ndim == 1:
3637
values = values[:, np.newaxis]
3738

38-
normed_values = self.normalise(values)
39+
normalised_values = self.normalise(values)
3940

40-
return torch.tensor(normed_values, dtype=torch.float32)
41+
return torch.tensor(normalised_values.astype(self.dtype))
4142

4243
def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor:
4344
"""Compute the edge attributes."""
@@ -155,3 +156,81 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:
155156
values = 1 - values
156157

157158
return values
159+
160+
161+
class BooleanBaseEdgeAttribute(BaseEdgeAttribute, ABC):
162+
"""Base class for boolean edge attributes."""
163+
164+
def __init__(self) -> None:
165+
super().__init__(norm=None, dtype="bool")
166+
167+
168+
class BaseAttributeFromNode(BooleanBaseEdgeAttribute, ABC):
169+
"""
170+
Base class for Attribute from Node.
171+
172+
Copy an attribute of either the source or target node to the edge.
173+
Accesses source/target node attribute and propagates it to the edge.
174+
Used for example to identify if an encoder edge originates from a LAM or global node.
175+
176+
Attributes
177+
----------
178+
node_attr_name : str
179+
Name of the node attribute to propagate.
180+
181+
Methods
182+
-------
183+
get_node_name(source_name, target_name)
184+
Return the name of the node to copy.
185+
186+
get_raw_values(graph, source_name, target_name)
187+
Computes the edge attribute from the source or target node attribute.
188+
189+
"""
190+
191+
def __init__(self, node_attr_name: str) -> None:
192+
super().__init__()
193+
self.node_attr_name = node_attr_name
194+
self.idx = None
195+
196+
@abstractmethod
197+
def get_node_name(self, source_name: str, target_name: str): ...
198+
199+
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray:
200+
201+
node_name = self.get_node_name(source_name, target_name)
202+
203+
edge_index = graph[(source_name, "to", target_name)].edge_index
204+
try:
205+
return graph[node_name][self.node_attr_name].numpy()[edge_index[self.idx]]
206+
207+
except KeyError:
208+
raise KeyError(
209+
f"{self.__class__.__name__} failed because the attribute '{self.node_attr_name}' is not defined for the nodes."
210+
)
211+
212+
213+
class AttributeFromSourceNode(BaseAttributeFromNode):
214+
"""
215+
Copy an attribute of the source node to the edge.
216+
"""
217+
218+
def __init__(self, node_attr_name: str) -> None:
219+
super().__init__(node_attr_name)
220+
self.idx = 0
221+
222+
def get_node_name(self, source_name: str, target_name: str):
223+
return source_name
224+
225+
226+
class AttributeFromTargetNode(BaseAttributeFromNode):
227+
"""
228+
Copy an attribute of the target node to the edge.
229+
"""
230+
231+
def __init__(self, node_attr_name: str) -> None:
232+
super().__init__(node_attr_name)
233+
self.idx = 1
234+
235+
def get_node_name(self, source_name: str, target_name: str):
236+
return target_name

graphs/tests/edges/test_edge_attributes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10+
from functools import partial
11+
1012
import pytest
1113
import torch
1214

15+
from anemoi.graphs.edges.attributes import AttributeFromSourceNode
16+
from anemoi.graphs.edges.attributes import AttributeFromTargetNode
1317
from anemoi.graphs.edges.attributes import EdgeDirection
1418
from anemoi.graphs.edges.attributes import EdgeLength
1519

@@ -36,3 +40,17 @@ def test_fail_edge_features(attribute_builder, graph_nodes_and_edges):
3640
"""Test edge attribute builder fails with unknown nodes."""
3741
with pytest.raises(AssertionError):
3842
attribute_builder.compute(graph_nodes_and_edges, ("test_nodes", "to", "unknown_nodes"))
43+
44+
45+
@pytest.mark.parametrize(
46+
"attribute_builder",
47+
[
48+
partial(AttributeFromSourceNode, node_attr_name="example_attr"),
49+
partial(AttributeFromTargetNode, node_attr_name="example_attr"),
50+
],
51+
)
52+
def test_fail_edge_attribute_from_node(attribute_builder, graph_nodes_and_edges):
53+
"""Test edge attribute builder fails with unknown nodes."""
54+
with pytest.raises(KeyError):
55+
builder_instance = attribute_builder()
56+
builder_instance.compute(graph_nodes_and_edges, ("test_nodes", "to", "test_nodes"))

0 commit comments

Comments
 (0)