Skip to content

Commit df2fb89

Browse files
committed
Fix examples
1 parent dee6518 commit df2fb89

File tree

7 files changed

+282
-76
lines changed

7 files changed

+282
-76
lines changed

data_prototype/containers.py

Lines changed: 156 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
import numpy as np
2020
import pandas as pd
2121

22+
from typing import TYPE_CHECKING
23+
24+
if TYPE_CHECKING:
25+
from .conversion_edge import Graph
26+
2227

2328
class _MatplotlibTransform(Protocol):
2429
def transform(self, verts):
@@ -148,9 +153,8 @@ def compatible(a: dict[str, Desc], b: dict[str, Desc]) -> bool:
148153
class DataContainer(Protocol):
149154
def query(
150155
self,
151-
# TODO 3D?!!
152-
coord_transform: _MatplotlibTransform,
153-
size: Tuple[int, int],
156+
graph: Graph,
157+
parent_coordinates: str = "axes",
154158
/,
155159
) -> Tuple[Dict[str, Any], Union[str, int]]:
156160
"""
@@ -208,8 +212,8 @@ def __init__(self, **data):
208212

209213
def query(
210214
self,
211-
coord_transform: _MatplotlibTransform,
212-
size: Tuple[int, int],
215+
graph: Graph,
216+
parent_coordinates: str = "axes",
213217
) -> Tuple[Dict[str, Any], Union[str, int]]:
214218
return dict(self._data), self._cache_key
215219

@@ -233,8 +237,8 @@ def __init__(self, **shapes):
233237

234238
def query(
235239
self,
236-
coord_transform: _MatplotlibTransform,
237-
size: Tuple[int, int],
240+
graph: Graph,
241+
parent_coordinates: str = "axes",
238242
) -> Tuple[Dict[str, Any], Union[str, int]]:
239243
return {k: np.random.randn(*d.shape) for k, d in self._desc.items()}, str(
240244
uuid.uuid4()
@@ -301,31 +305,101 @@ def _query_hash(self, coord_transform, size):
301305

302306
def query(
303307
self,
304-
coord_transform: _MatplotlibTransform,
305-
size: Tuple[int, int],
308+
graph: Graph,
309+
parent_coordinates: str = "axes",
306310
) -> Tuple[Dict[str, Any], Union[str, int]]:
307-
hash_key = self._query_hash(coord_transform, size)
308-
if hash_key in self._cache:
309-
return self._cache[hash_key], hash_key
311+
# hash_key = self._query_hash(coord_transform, size)
312+
# if hash_key in self._cache:
313+
# return self._cache[hash_key], hash_key
314+
315+
data_lim = graph.evaluator(
316+
{
317+
"x": Desc(
318+
("N",),
319+
np.dtype(
320+
"f8",
321+
),
322+
coordinates="data",
323+
),
324+
"y": Desc(
325+
("N",),
326+
np.dtype(
327+
"f8",
328+
),
329+
coordinates="data",
330+
),
331+
},
332+
{
333+
"x": Desc(
334+
("N",),
335+
np.dtype(
336+
"f8",
337+
),
338+
coordinates=parent_coordinates,
339+
),
340+
"y": Desc(
341+
("N",),
342+
np.dtype(
343+
"f8",
344+
),
345+
coordinates=parent_coordinates,
346+
),
347+
},
348+
).inverse
349+
screen_size = graph.evaluator(
350+
{
351+
"x": Desc(
352+
("N",),
353+
np.dtype(
354+
"f8",
355+
),
356+
coordinates=parent_coordinates,
357+
),
358+
"y": Desc(
359+
("N",),
360+
np.dtype(
361+
"f8",
362+
),
363+
coordinates=parent_coordinates,
364+
),
365+
},
366+
{
367+
"x": Desc(
368+
("N",),
369+
np.dtype(
370+
"f8",
371+
),
372+
coordinates="display",
373+
),
374+
"y": Desc(
375+
("N",),
376+
np.dtype(
377+
"f8",
378+
),
379+
coordinates="display",
380+
),
381+
},
382+
)
310383

311-
xpix, ypix = size
312-
x_data, _ = coord_transform.transform(
313-
np.vstack(
314-
[
315-
np.linspace(0, 1, int(xpix) * 2),
316-
np.zeros(int(xpix) * 2),
317-
]
318-
).T
319-
).T
320-
_, y_data = coord_transform.transform(
321-
np.vstack(
322-
[
323-
np.zeros(int(ypix) * 2),
324-
np.linspace(0, 1, int(ypix) * 2),
325-
]
326-
).T
327-
).T
384+
screen_dims = screen_size.evaluate({"x": [0, 1], "y": [0, 1]})
385+
xpix, ypix = np.ceil(np.abs(np.diff(screen_dims["x"]))), np.ceil(
386+
np.abs(np.diff(screen_dims["y"]))
387+
)
328388

389+
x_data = data_lim.evaluate(
390+
{
391+
"x": np.linspace(0, 1, int(xpix) * 2),
392+
"y": np.zeros(int(xpix) * 2),
393+
}
394+
)["x"]
395+
y_data = data_lim.evaluate(
396+
{
397+
"x": np.zeros(int(ypix) * 2),
398+
"y": np.linspace(0, 1, int(ypix) * 2),
399+
}
400+
)["y"]
401+
402+
hash_key = str(uuid.uuid4())
329403
ret = self._cache[hash_key] = dict(
330404
**{k: f(x_data) for k, f in self._xfuncs.items()},
331405
**{k: f(y_data) for k, f in self._yfuncs.items()},
@@ -350,11 +424,49 @@ def __init__(self, raw_data, num_bins: int):
350424

351425
def query(
352426
self,
353-
coord_transform: _MatplotlibTransform,
354-
size: Tuple[int, int],
427+
graph: Graph,
428+
parent_coordinates: str = "axes",
355429
) -> Tuple[Dict[str, Any], Union[str, int]]:
356430
dmin, dmax = self._full_range
357-
xmin, ymin, xmax, ymax = coord_transform.transform([[0, 0], [1, 1]]).flatten()
431+
432+
data_lim = graph.evaluator(
433+
{
434+
"x": Desc(
435+
("N",),
436+
np.dtype(
437+
"f8",
438+
),
439+
coordinates="data",
440+
),
441+
"y": Desc(
442+
("N",),
443+
np.dtype(
444+
"f8",
445+
),
446+
coordinates="data",
447+
),
448+
},
449+
{
450+
"x": Desc(
451+
("N",),
452+
np.dtype(
453+
"f8",
454+
),
455+
coordinates=parent_coordinates,
456+
),
457+
"y": Desc(
458+
("N",),
459+
np.dtype(
460+
"f8",
461+
),
462+
coordinates=parent_coordinates,
463+
),
464+
},
465+
).inverse
466+
467+
pts = data_lim.evaluate({"x": (0, 1), "y": (0, 1)})
468+
xmin, xmax = pts["x"]
469+
ymin, ymax = pts["y"]
358470

359471
xmin, xmax = np.clip([xmin, xmax], dmin, dmax)
360472
hash_key = hash((xmin, xmax))
@@ -398,8 +510,8 @@ def __init__(self, series: pd.Series, *, index_name: str, col_name: str):
398510

399511
def query(
400512
self,
401-
coord_transform: _MatplotlibTransform,
402-
size: Tuple[int, int],
513+
graph: Graph,
514+
parent_coordinates: str = "axes",
403515
) -> Tuple[Dict[str, Any], Union[str, int]]:
404516
return {
405517
self._index_name: self._data.index.values,
@@ -440,8 +552,8 @@ def __init__(
440552

441553
def query(
442554
self,
443-
coord_transform: _MatplotlibTransform,
444-
size: Tuple[int, int],
555+
graph: Graph,
556+
parent_coordinates: str = "axes",
445557
) -> Tuple[Dict[str, Any], Union[str, int]]:
446558
ret = {}
447559
if self._index_name is not None:
@@ -463,10 +575,10 @@ def __init__(self, data: DataContainer, mapping: Dict[str, str]):
463575

464576
def query(
465577
self,
466-
coord_transform: _MatplotlibTransform,
467-
size: Tuple[int, int],
578+
graph: Graph,
579+
parent_coordinates: str = "axes",
468580
) -> Tuple[Dict[str, Any], Union[str, int]]:
469-
base, cache_key = self._data.query(coord_transform, size)
581+
base, cache_key = self._data.query(graph, parent_coordinates)
470582
return {v: base[k] for k, v in self._mapping.items()}, cache_key
471583

472584
def describe(self):
@@ -481,13 +593,13 @@ def __init__(self, *data: DataContainer):
481593

482594
def query(
483595
self,
484-
coord_transform: _MatplotlibTransform,
485-
size: Tuple[int, int],
596+
graph: Graph,
597+
parent_coordinates: str = "axes",
486598
) -> Tuple[Dict[str, Any], Union[str, int]]:
487599
cache_keys = []
488600
ret = {}
489601
for data in self._datas:
490-
base, cache_key = data.query(coord_transform, size)
602+
base, cache_key = data.query(graph, parent_coordinates)
491603
ret.update(base)
492604
cache_keys.append(cache_key)
493605
return ret, hash(tuple(cache_keys))
@@ -499,8 +611,8 @@ def describe(self):
499611
class WebServiceContainer:
500612
def query(
501613
self,
502-
coord_transform: _MatplotlibTransform,
503-
size: Tuple[int, int],
614+
graph: Graph,
615+
parent_coordinates: str = "axes",
504616
) -> Tuple[Dict[str, Any], Union[str, int]]:
505617
def hit_some_database():
506618
{}, "1"

data_prototype/conversion_edge.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ class Edge:
1414
name: str
1515
input: dict[str, Desc]
1616
output: dict[str, Desc]
17-
invertable: bool = False
17+
invertable: bool = True
1818

1919
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
2020
return input
2121

2222
@property
2323
def inverse(self) -> "Edge":
24-
raise NotImplementedError
24+
return Edge(self.name + "_r", self.output, self.input)
2525

2626

2727
@dataclass
@@ -42,14 +42,23 @@ def from_edges(cls, name: str, edges: Sequence[Edge], output: dict[str, Desc]):
4242

4343
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
4444
for edge in self.edges:
45-
input |= edge.evaluate(**{k: input[k] for k in edge.input})
45+
print(input)
46+
input |= edge.evaluate({k: input[k] for k in edge.input})
47+
print(input)
4648
return {k: input[k] for k in self.output}
4749

50+
@property
51+
def inverse(self) -> "SequenceEdge":
52+
return SequenceEdge.from_edges(
53+
self.name + "_r", [e.inverse for e in self.edges[::-1]], self.input
54+
)
55+
4856

4957
@dataclass
5058
class FuncEdge(Edge):
5159
# TODO: more explicit callable boundaries?
5260
func: Callable = lambda: {}
61+
inverse_func: Callable | None = None
5362

5463
@classmethod
5564
def from_func(
@@ -58,6 +67,7 @@ def from_func(
5867
func: Callable,
5968
input: str | dict[str, Desc],
6069
output: str | dict[str, Desc],
70+
inverse: Callable | None = None,
6171
):
6272
# dtype/shape is reductive here, but I like the idea of being able to just
6373
# supply a function and the input/output coordinates for many things
@@ -69,7 +79,7 @@ def from_func(
6979
if isinstance(output, str):
7080
output = {k: Desc(("N",), np.dtype("f8"), output) for k in input.keys()}
7181

72-
return cls(name, input, output, False, func)
82+
return cls(name, input, output, inverse is not None, func, inverse)
7383

7484
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
7585
res = self.func(**{k: input[k] for k in self.input})
@@ -88,10 +98,16 @@ def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
8898
return {k: res for k in self.output}
8999
raise RuntimeError("Output of function does not match expected output")
90100

101+
@property
102+
def inverse(self) -> "FuncEdge":
103+
return FuncEdge.from_func(
104+
self.name + "_r", self.inverse_func, self.output, self.input, self.func
105+
)
106+
91107

92108
@dataclass
93109
class TransformEdge(Edge):
94-
transform: Transform | None = None
110+
transform: Transform | Callable[[], Transform] | None = None
95111

96112
# TODO: helper for common cases/validation?
97113

@@ -101,10 +117,29 @@ def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
101117
# especially if initially given as stacked
102118
if self.transform is None:
103119
return input
120+
elif isinstance(self.transform, Callable):
121+
trf = self.transform()
122+
else:
123+
trf = self.transform
104124
inp = np.stack([input[k] for k in self.input], axis=-1)
105-
outp = self.transform.transform(inp)
125+
outp = trf.transform(inp)
106126
return {k: v for k, v in zip(self.output, outp.T)}
107127

128+
@property
129+
def inverse(self) -> "TransformEdge":
130+
if isinstance(self.transform, Callable):
131+
return TransformEdge(
132+
self.name + "_r",
133+
self.output,
134+
self.input,
135+
True,
136+
lambda: self.transform().inverted(),
137+
)
138+
139+
return TransformEdge(
140+
self.name + "_r", self.output, self.input, True, self.transform.inverted()
141+
)
142+
108143

109144
class Graph:
110145
def __init__(self, edges: Sequence[Edge]):

0 commit comments

Comments
 (0)