Skip to content

Commit dee6518

Browse files
committed
FuncEdge and TransformEdge
1 parent 623aa74 commit dee6518

File tree

1 file changed

+67
-2
lines changed

1 file changed

+67
-2
lines changed

data_prototype/conversion_edge.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from collections.abc import Sequence
2+
from typing import Callable
23
from dataclasses import dataclass
34
from typing import Any
5+
import numpy as np
46

5-
from .containers import Desc
7+
from data_prototype.containers import Desc
8+
9+
from matplotlib.transforms import Transform
610

711

812
@dataclass
@@ -42,6 +46,66 @@ def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
4246
return {k: input[k] for k in self.output}
4347

4448

49+
@dataclass
50+
class FuncEdge(Edge):
51+
# TODO: more explicit callable boundaries?
52+
func: Callable = lambda: {}
53+
54+
@classmethod
55+
def from_func(
56+
cls,
57+
name: str,
58+
func: Callable,
59+
input: str | dict[str, Desc],
60+
output: str | dict[str, Desc],
61+
):
62+
# dtype/shape is reductive here, but I like the idea of being able to just
63+
# supply a function and the input/output coordinates for many things
64+
if isinstance(input, str):
65+
import inspect
66+
67+
input_vars = inspect.signature(func).parameters.keys()
68+
input = {k: Desc(("N",), np.dtype("f8"), input) for k in input_vars}
69+
if isinstance(output, str):
70+
output = {k: Desc(("N",), np.dtype("f8"), output) for k in input.keys()}
71+
72+
return cls(name, input, output, False, func)
73+
74+
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
75+
res = self.func(**{k: input[k] for k in self.input})
76+
77+
if isinstance(res, dict):
78+
# TODO: more sanity checks here?
79+
# How forgiving do we _really_ wish to be?
80+
return res
81+
elif isinstance(res, tuple):
82+
if len(res) != len(self.output):
83+
raise RuntimeError(
84+
f"Expected {len(self.output)} return values, got {len(res)}"
85+
)
86+
return {k: v for k, v in zip(self.output, res)}
87+
elif len(self.output) == 1:
88+
return {k: res for k in self.output}
89+
raise RuntimeError("Output of function does not match expected output")
90+
91+
92+
@dataclass
93+
class TransformEdge(Edge):
94+
transform: Transform | None = None
95+
96+
# TODO: helper for common cases/validation?
97+
98+
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
99+
# TODO: ensure ordering?
100+
# Stacking and unstacking at every step seems inefficient,
101+
# especially if initially given as stacked
102+
if self.transform is None:
103+
return input
104+
inp = np.stack([input[k] for k in self.input], axis=-1)
105+
outp = self.transform.transform(inp)
106+
return {k: v for k, v in zip(self.output, outp.T)}
107+
108+
45109
class Graph:
46110
def __init__(self, edges: Sequence[Edge]):
47111
self._edges = edges
@@ -123,6 +187,7 @@ def node_format(x):
123187
)
124188

125189
pos = nx.planar_layout(G)
190+
plt.figure()
126191
nx.draw(G, pos=pos, with_labels=True)
127192
nx.draw_networkx_edge_labels(G, pos=pos)
128-
plt.show()
193+
# plt.show()

0 commit comments

Comments
 (0)