Skip to content

Commit 623aa74

Browse files
committed
Initial rework of conversions as edges
1 parent 2b33f65 commit 623aa74

File tree

2 files changed

+146
-3
lines changed

2 files changed

+146
-3
lines changed

data_prototype/containers.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from dataclasses import dataclass
24
from typing import (
35
Protocol,
@@ -38,10 +40,8 @@ class Desc:
3840
# - what is the relative size to the other variable values (N vs N+1)
3941
# We are probably going to have to implement a DSL for this (😞)
4042
shape: ShapeSpec
41-
# TODO: is using a string better?
4243
dtype: np.dtype
43-
# TODO: do we want to include this at this level? "naive" means unit-unaware.
44-
units: str = "naive"
44+
coordinates: str = "naive"
4545

4646
@staticmethod
4747
def validate_shapes(
@@ -129,6 +129,21 @@ def validate_shapes(
129129
)
130130
return None
131131

132+
@staticmethod
def compatible(a: dict[str, Desc], b: dict[str, Desc]) -> bool:
    """Report whether the descriptions in ``a`` can be fed into ``b``.

    ``a`` is allowed to carry keys that ``b`` does not mention.

    Parameters
    ----------
    a
        Descriptions of the values that are available.
    b
        Descriptions of the values that are required.

    Returns
    -------
    bool
        ``False`` when ``Desc.validate_shapes(b, a)`` raises ``KeyError``
        or ``ValueError``, or when any key of ``b`` has a ``coordinates``
        value that differs from the matching entry in ``a``; ``True``
        otherwise.
    """
    try:
        Desc.validate_shapes(b, a)
    except (KeyError, ValueError):
        return False
    # Every requested entry must live in the same coordinate system.
    return all(
        a[key].coordinates == required.coordinates for key, required in b.items()
    )
146+
132147

133148
class DataContainer(Protocol):
134149
def query(

data_prototype/conversion_edge.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from collections.abc import Sequence
2+
from dataclasses import dataclass
3+
from typing import Any
4+
5+
from .containers import Desc
6+
7+
8+
@dataclass
class Edge:
    """A named conversion from one set of described values to another.

    The base class is a pass-through: ``evaluate`` hands back exactly what
    it was given, and no inverse is defined.
    """

    # Label for the edge (used, e.g., when the graph is drawn).
    name: str
    # Descriptions of the values this edge consumes, keyed by name.
    input: dict[str, Desc]
    # Descriptions of the values this edge produces, keyed by name.
    output: dict[str, Desc]
    # Presumably marks edges for which ``inverse`` is meaningful; the base
    # class never implements one.
    invertable: bool = False

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        """Return ``input`` unchanged (the very same object)."""
        return input

    @property
    def inverse(self) -> "Edge":
        """Inverse edge; not provided by the base class.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError()
21+
22+
23+
@dataclass
class SequenceEdge(Edge):
    """An edge that applies a sequence of other edges one after another."""

    # Edges to run, in evaluation order.
    edges: Sequence[Edge] = ()

    @classmethod
    def from_edges(cls, name: str, edges: Sequence[Edge], output: dict[str, Desc]):
        """Build a ``SequenceEdge`` whose input is derived from ``edges``.

        Parameters
        ----------
        name
            Name given to the combined edge.
        edges
            Edges to chain, in evaluation order.
        output
            Descriptions of the values the combined edge should return.
            Stored as given; it is not checked against what ``edges``
            actually produce.

        Returns
        -------
        SequenceEdge
            Its ``input`` holds every input of the chained edges that is
            not already produced by an earlier edge in the chain, and it is
            ``invertable`` only if every chained edge is.
        """
        input = {}
        intermediates = {}
        invertable = True
        for edge in edges:
            # Only values that no earlier edge has produced must come from
            # the caller.
            input |= {k: v for k, v in edge.input.items() if k not in intermediates}
            intermediates |= edge.output
            if not edge.invertable:
                invertable = False
        return cls(name, input, output, invertable, edges)

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        """Run every edge in order and return the requested outputs.

        Parameters
        ----------
        input
            Values keyed by name.  Must contain every key an edge needs
            that is not produced by an earlier edge.  Not modified.

        Returns
        -------
        dict
            The values named in ``self.output``.

        Raises
        ------
        KeyError
            If a value required by an edge or by ``self.output`` is missing.
        """
        # Work on a copy so the caller's dictionary is left untouched.
        values = dict(input)
        for edge in self.edges:
            # ``Edge.evaluate`` takes a single dict, so the selected values
            # are passed as one positional argument rather than unpacked
            # into keyword arguments.
            values |= edge.evaluate({k: values[k] for k in edge.input})
        return {k: values[k] for k in self.output}
43+
44+
45+
class Graph:
    """A set of conversion edges that can be searched for conversion paths."""

    def __init__(self, edges: Sequence[Edge]):
        """Store the edges that make up the graph.

        Parameters
        ----------
        edges
            Every conversion edge available to the graph.
        """
        self._edges = edges
        # TODO: precompute some internal representation?
        #   - Nodes between edges, potentially in discrete subgraphs
        #   - Inversions are not included right now

    def evaluator(self, input: dict[str, Desc], output: dict[str, Desc]) -> Edge:
        """Find an edge (or chain of edges) leading from ``input`` to ``output``.

        Parameters
        ----------
        input
            Descriptions of the values that are available.
        output
            Descriptions of the values that are wanted.

        Returns
        -------
        Edge
            A pass-through ``Edge`` named ``"noop"`` if ``input`` already
            satisfies ``output``, the single edge if one suffices, and
            otherwise a ``SequenceEdge`` named ``"eval"`` chaining the
            edges that were found.

        Raises
        ------
        NotImplementedError
            If the search runs out of states without reaching ``output``.
        """
        # May wish to solve for each output independently
        # Probably can be smarter here and prune more effectively.
        q: list[tuple[dict[str, Desc], tuple[Edge, ...]]] = [(input, ())]

        def trace(x: dict[str, Desc]) -> tuple[tuple[str, str], ...]:
            # Hashable fingerprint of a state: its (name, coordinates) pairs.
            return tuple(sorted([(k, v.coordinates) for k, v in x.items()]))

        explored: set[tuple[tuple[str, str], ...]] = {trace(input)}
        edges: tuple[Edge, ...] = ()
        while q:
            # Popping from the end means the most recently found state is
            # expanded first.
            v, edges = q.pop()
            if Desc.compatible(v, output):
                break
            for e in self._edges:
                if not Desc.compatible(v, e.input):
                    continue
                w = v | e.output
                w_trace = trace(w)
                if w_trace in explored:
                    # This may need to be more explicitly checked...
                    # May not accurately be checking what we consider "in"
                    continue
                explored.add(w_trace)
                q.append((w, (*edges, e)))
        else:
            # Reached only when the queue empties without hitting ``break``.
            # TODO: case where non-linear solving is needed
            raise NotImplementedError(
                "This may be possible, but is not a simple case already considered"
            )
        if not edges:
            return Edge("noop", input, output)
        if len(edges) == 1:
            return edges[0]
        return SequenceEdge.from_edges("eval", edges, output)

    def visualize(self, input: dict[str, Desc] | None = None):
        """Draw the graph with networkx and matplotlib and show the figure.

        Parameters
        ----------
        input
            If given, draw the states reachable from ``input``, with one
            node per distinct state.  If ``None``, draw every edge as a
            link from its declared input to its declared output.
        """
        import networkx as nx
        import matplotlib.pyplot as plt
        from pprint import pformat

        def node_format(x):
            # Node label: pretty-printed mapping of name -> coordinates.
            return pformat({k: v.coordinates for k, v in x.items()})

        G = nx.DiGraph()

        if input is not None:
            q: list[dict[str, Desc]] = [input]
            G.add_node(node_format(input))
            while q:
                n = q.pop()
                n_label = node_format(n)
                for e in self._edges:
                    if not Desc.compatible(n, e.input):
                        continue
                    w = n | e.output
                    w_label = node_format(w)
                    # Membership in ``G`` is what marks a state as already
                    # visited, so no separate bookkeeping is required.
                    if w_label not in G:
                        G.add_node(w_label)
                        q.append(w)
                    # Skip self-loops for edges that leave the state as is.
                    if w_label != n_label:
                        G.add_edge(n_label, w_label, name=e.name)
        else:
            for edge in self._edges:
                G.add_edge(
                    node_format(edge.input), node_format(edge.output), name=edge.name
                )

        pos = nx.planar_layout(G)
        nx.draw(G, pos=pos, with_labels=True)
        nx.draw_networkx_edge_labels(G, pos=pos)
        plt.show()

0 commit comments

Comments
 (0)