Skip to content

Commit 1074f2f

Browse files
Add ndarray demo
1 parent e4e7fbc commit 1074f2f

File tree

11 files changed

+257
-25
lines changed

11 files changed

+257
-25
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ crate-type = ["cdylib"]
1212
# Use unreleased version to depend on signature improvements https://github.com/PyO3/pyo3/pull/2702
1313
pyo3 = { version = "0.18.1", features = ["extension-module"] }
1414
# egglog = { git = "https://github.com/egraphs-good/egglog", rev = "39b199d9bfce9cc47d0c54977279c5b04231e717" }
15-
egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "de202c4930a9a983361bd7af10f8d6c1a3740c08" }
15+
egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "989f0527d46550f05562a3a8c0e87a0b1280a930" }
1616

1717
# egglog = { path = "../egg-smol" }
1818
pyo3-log = "0.8.1"

Presentation.ipynb

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,5 @@
11
{
22
"cells": [
3-
{
4-
"cell_type": "markdown",
5-
"id": "a74307d3-b810-4f9d-9663-d5b0030357ae",
6-
"metadata": {},
7-
"source": [
8-
"* Campy\n",
9-
"* Add more complicated example\n",
10-
"* Change fonts\n",
11-
"* Make sure examples are done"
12-
]
13-
},
143
{
154
"cell_type": "markdown",
165
"id": "a611c138-1afd-4d6f-9578-4f358d2438eb",

docs/changelog.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,15 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea
1111
- Added `Set` sort and removed set method from `Map`
1212
- Added `Vec` sort
1313
- Added support for variable args for builtin functions, to use in creation of `Vec` and `Set` sorts.
14+
- Added suport for joining `String`s
1415
- Switch generated egg names to use `.` as seperate (i.e. `Math.__add__`) instead of `_` (i.e. `Math___add__`)
1516
- Adds support for modules to define functions/sorts/rules without executing them, for reuse in other modules
1617
- Moved simplifying and running rulesets to the `run` and `simplify` methods on the `EGraph` from those methods on the `Ruleset` since we can now create `Rulset`s for modules which don't have an EGraph attached and can't be run
18+
- Fixed extracting classmethods which required generic args to cls
19+
- Added support for alternative way of creating variables using functions
20+
- Add NDarray example
21+
- Render EGraphs with `graphviz` in the notebook
22+
- Add `%%egglog` magic to the notebook
1723

1824
## 0.4.0 (2023-05-03)
1925

docs/reference/egglog-translation.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,21 @@ Since it uses a fluent API, static type checkers can verify that the type of the
328328

329329
The `(birewrite ...)` command in egglog is syntactic sugar for creating two rewrites, one in each direction. In Python, we can use the `birewrite(expr).to(expr, *when)` function to create two rules that rewrite in each direction.
330330

331+
### Using funcitons to define vars
332+
333+
Instead of defining variables with `vars_`, we can also use functions to define variables. This can be more succinct
334+
and also will make sure the variables won't be used outside of the scope of the function.
335+
336+
```{code-cell} python
337+
# egg: (rewrite (Mul a b) (Mul b a))
338+
# egg: (rewrite (Add a b) (Add b a))
339+
340+
@egraph.register
341+
def _math(a: Math, b: Math)
342+
yield rewrite(a * b).to(b * a)
343+
yield rewrite(a + b).to(b + a)
344+
```
345+
331346
## Running
332347

333348
To run the egraph, we can use the `egraph.run()` function. This will run until a fixed point is reached, or until a timeout is reached.

python/egglog/builtins.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"Rational",
2222
"Set",
2323
"Vec",
24+
"join",
2425
]
2526

2627

@@ -147,13 +148,19 @@ def max(self, other: f64Like) -> f64: # type: ignore[empty-body]
147148
...
148149

149150

151+
StringLike = Union[str, "String"]
152+
153+
150154
@BUILTINS.class_
151155
class String(BaseExpr):
152156
def __init__(self, value: str):
153157
...
154158

155159

156-
StringLike = Union[str, String]
160+
@BUILTINS.function(egg_fn="+")
161+
def join(*strings: StringLike) -> String: # type: ignore[empty-body]
162+
...
163+
157164

158165
T = TypeVar("T", bound=BaseExpr)
159166
V = TypeVar("V", bound=BaseExpr)
@@ -328,3 +335,11 @@ def not_contains(self, value: T) -> Unit: # type: ignore[empty-body]
328335
@BUILTINS.method(egg_fn="vec-contains")
329336
def contains(self, value: T) -> Unit: # type: ignore[empty-body]
330337
...
338+
339+
@BUILTINS.method(egg_fn="vec-length")
340+
def length(self) -> i64: # type: ignore[empty-body]
341+
...
342+
343+
@BUILTINS.method(egg_fn="vec-get")
344+
def __getitem__(self, index: i64Like) -> T: # type: ignore[empty-body]
345+
...

python/egglog/declarations.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,18 @@ def from_egg(cls, mod_decls: ModuleDeclarations, call: bindings.Call) -> TypedEx
536536

537537
# Find the first callable ref that matches the call
538538
for callable_ref in mod_decls.get_callable_refs(call.name):
539-
tcs = TypeConstraintSolver()
539+
# If this is a classmethod, we might need the type params that were bound for this type
540+
# egglog currently only allows one instantiated type of any generic sort to be used in any program
541+
# So we just lookup what args were registered for thsi sort
542+
if isinstance(callable_ref, ClassMethodRef):
543+
for registered_tp in mod_decls._decl._type_ref_to_egg_sort.keys():
544+
if registered_tp.name == callable_ref.class_name:
545+
tcs = TypeConstraintSolver.from_type_parameters(registered_tp.args)
546+
break
547+
else:
548+
raise ValueError(f"Could not find type parameters for class {callable_ref.class_name}")
549+
else:
550+
tcs = TypeConstraintSolver()
540551
fn_decl = mod_decls.get_function_decl(callable_ref)
541552
return_tp = tcs.infer_return_type(fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, arg_types)
542553
return TypedExprDecl(return_tp, cls(callable_ref, tuple(results)))

python/egglog/egraph.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
overload,
2424
)
2525

26+
import graphviz
2627
from egglog.declarations import Declarations
2728
from typing_extensions import ParamSpec, get_args, get_origin
2829

@@ -214,7 +215,10 @@ def _class(
214215
# If this is an i64, use the runtime class for the alias so that i64Like is resolved properly
215216
# Otherwise, this might be a Map in which case pass in the original cls so that we
216217
# can do Map[T, V] on it, which is not allowed on the runtime class
217-
cls_type_and_name=(RuntimeClass(self._mod_decls, "i64") if cls_name == "i64" else cls, cls_name),
218+
cls_type_and_name=(
219+
RuntimeClass(self._mod_decls, cls_name) if cls_name in {"i64", "String"} else cls,
220+
cls_name,
221+
),
218222
)
219223

220224
# Register != as a method so we can print it as a string
@@ -450,10 +454,15 @@ def _resolve_type_annotation(
450454
return class_to_ref(tp).to_var()
451455
raise TypeError(f"Unexpected type annotation {tp}")
452456

453-
def register(self, *commands: CommandLike) -> None:
457+
def register(self, command_or_generator: CommandLike | CommandGenerator, *commands: CommandLike) -> None:
454458
"""
455459
Registers any number of rewrites or rules.
456460
"""
461+
if isinstance(command_or_generator, FunctionType):
462+
assert not commands
463+
commands = tuple(_command_generator(command_or_generator))
464+
else:
465+
commands = (cast(CommandLike, command_or_generator), *commands)
457466
self._process_commands(_command_like(command)._to_egg_command(self._mod_decls) for command in commands)
458467

459468
def ruleset(self, name: str) -> Ruleset:
@@ -587,9 +596,12 @@ def _repr_mimebundle_(self, *args, **kwargs):
587596
"""
588597
Returns the graphviz representation of the e-graph.
589598
"""
590-
import graphviz
591599

592-
return graphviz.Source(self._egraph.to_graphviz_string())._repr_mimebundle_(*args, **kwargs)
600+
return self.graphviz._repr_mimebundle_(*args, **kwargs)
601+
602+
@property
603+
def graphviz(self) -> graphviz.Source:
604+
return graphviz.Source(self._egraph.to_graphviz_string())
593605

594606
def display(self):
595607
"""
@@ -982,6 +994,18 @@ def _command_like(command_like: CommandLike) -> Command:
982994
return command_like
983995

984996

997+
CommandGenerator = Callable[..., Iterable[Command]]
998+
999+
1000+
def _command_generator(gen: CommandGenerator) -> Iterable[Command]:
1001+
"""
1002+
Calls the function with variables of the type and name of the arguments.
1003+
"""
1004+
hints = get_type_hints(gen)
1005+
args = (_var(p.name, hints[p.name]) for p in signature(gen).parameters.values())
1006+
return gen(*args)
1007+
1008+
9851009
ActionLike = Union[Action, BaseExpr]
9861010

9871011

python/egglog/examples/lambda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def assert_simplifies(left: BaseExpr, right: BaseExpr) -> None:
181181
egraph.run(30)
182182
res = egraph.extract(left)
183183
print(f"{left}{res}")
184-
egraph.check(eq(right).to(right))
184+
egraph.check(eq(left).to(right))
185185

186186

187187
assert_simplifies((Term.val(Val(1))).eval(), Val(1))

python/egglog/examples/ndarrays.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""
2+
N-Dimensional Arrays
3+
====================
4+
"""
5+
# mypy: disable-error-code=empty-body
6+
from __future__ import annotations
7+
8+
from egglog import *
9+
10+
# Story
11+
# Start with eager arrays
12+
# Then move onto lazy arrays.
13+
14+
# Then show how the point is we can do this with different packages...
15+
16+
# 1. Wonderful ndarray library with execution on my cpu
17+
# 2. Come along with that allows generation of code for another platform, thats batched
18+
# 3. Add cross product, works on both platforms
19+
# 4. Add specialisation for lazy platform.
20+
21+
# 1. Different people to do different parts, without explicit coordination... No need to ask for permission.
22+
# 2. Can
23+
24+
25+
egraph = EGraph()
26+
27+
28+
@egraph.class_
29+
class Value(BaseExpr):
30+
def __init__(self, v: i64Like) -> None:
31+
...
32+
33+
def __mul__(self, other: Value) -> Value:
34+
...
35+
36+
def __add__(self, other: Value) -> Value:
37+
...
38+
39+
40+
i, j = vars_("i j", i64)
41+
egraph.register(
42+
rewrite(Value(i) * Value(j)).to(Value(i * j)),
43+
rewrite(Value(i) + Value(j)).to(Value(i + j)),
44+
)
45+
46+
47+
@egraph.class_
48+
class Values(BaseExpr):
49+
def __init__(self, v: Vec[Value]) -> None:
50+
...
51+
52+
def __getitem__(self, idx: Value) -> Value:
53+
...
54+
55+
def length(self) -> Value:
56+
...
57+
58+
def concat(self, other: Values) -> Values:
59+
...
60+
61+
62+
@egraph.register
63+
def _values(vs: Vec[Value], other: Vec[Value]):
64+
yield rewrite(Values(vs)[Value(i)]).to(vs[i])
65+
yield rewrite(Values(vs).length()).to(Value(vs.length()))
66+
yield rewrite(Values(vs).concat(Values(other))).to(Values(vs.append(other)))
67+
# yield rewrite(l.concat(r).length()).to(l.length() + r.length())
68+
# yield rewrite(l.concat(r)[idx])
69+
70+
71+
@egraph.class_
72+
class NDArray(BaseExpr):
73+
"""
74+
An n-dimensional array.
75+
"""
76+
77+
def __getitem__(self, idx: Values) -> Value:
78+
...
79+
80+
def shape(self) -> Values:
81+
...
82+
83+
84+
@egraph.function
85+
def arange(n: Value) -> NDArray:
86+
...
87+
88+
89+
@egraph.register
90+
def _ndarray_arange(n: Value, idx: Values):
91+
yield rewrite(arange(n).shape()).to(Values(Vec(n)))
92+
yield rewrite(arange(n)[idx]).to(idx[Value(0)])
93+
94+
95+
def assert_simplifies(left: BaseExpr, right: BaseExpr) -> None:
96+
"""
97+
Simplify and print
98+
"""
99+
egraph.register(left)
100+
egraph.run(30)
101+
res = egraph.extract(left)
102+
print(f"{left} == {right}{res}")
103+
egraph.check(eq(left).to(right))
104+
105+
106+
assert_simplifies(arange(Value(10)).shape(), Values(Vec(Value(10))))
107+
assert_simplifies(arange(Value(10))[Values(Vec(Value(0)))], Value(0))
108+
assert_simplifies(arange(Value(10))[Values(Vec(Value(1)))], Value(1))
109+
110+
111+
@egraph.function
112+
def py_value(s: StringLike) -> Value:
113+
...
114+
115+
116+
@egraph.register
117+
def _py_value(l: String, r: String):
118+
yield rewrite(py_value(l) + py_value(r)).to(py_value(join(l, " + ", r)))
119+
yield rewrite(py_value(l) * py_value(r)).to(py_value(join(l, " * ", r)))
120+
121+
122+
@egraph.function
123+
def py_values(s: StringLike) -> Values:
124+
...
125+
126+
127+
@egraph.register
128+
def _py_values(l: String, r: String):
129+
yield rewrite(py_values(l)[py_value(r)]).to(py_value(join(l, "[", r, "]")))
130+
yield rewrite(py_values(l).length()).to(py_value(join("len(", l, ")")))
131+
yield rewrite(py_values(l).concat(py_values(r))).to(py_values(join(l, " + ", r)))
132+
133+
134+
@egraph.function
135+
def py_ndarray(s: StringLike) -> NDArray:
136+
...
137+
138+
139+
@egraph.register
140+
def _py_ndarray(l: String, r: String):
141+
yield rewrite(py_ndarray(l)[py_values(r)]).to(py_value(join(l, "[", r, "]")))
142+
yield rewrite(py_ndarray(l).shape()).to(py_values(join(l, ".shape")))
143+
yield rewrite(arange(py_value(l))).to(py_ndarray(join("np.arange(", l, ")")))
144+
145+
146+
assert_simplifies(py_ndarray("x").shape(), py_values("x.shape"))
147+
assert_simplifies(arange(py_value("x"))[py_values("y")], py_value("np.arange(x)[y]"))
148+
# assert_simplifies(arange(py_value("x"))[py_values("y")], py_value("y[0]"))
149+
150+
151+
@egraph.function
152+
def cross(l: NDArray, r: NDArray) -> NDArray:
153+
...
154+
155+
156+
@egraph.register
157+
def _cross(l: NDArray, r: NDArray, idx: Values):
158+
yield rewrite(cross(l, r).shape()).to(l.shape().concat(r.shape()))
159+
yield rewrite(cross(l, r)[idx]).to(l[idx] * r[idx])
160+
161+
162+
assert_simplifies(cross(arange(Value(10)), arange(Value(11))).shape(), Values(Vec(Value(10), Value(11))))
163+
assert_simplifies(cross(py_ndarray("x"), py_ndarray("y")).shape(), py_values("x.shape + y.shape"))
164+
assert_simplifies(cross(py_ndarray("x"), py_ndarray("y"))[py_values("idx")], py_value("x[idx] * y[idx]"))
165+
166+
167+
@egraph.register
168+
def _cross_py(l: String, r: String):
169+
yield rewrite(cross(py_ndarray(l), py_ndarray(r))).to(py_ndarray(join("np.multiply.outer(", l, ", ", r, ")")))
170+
171+
172+
assert_simplifies(cross(py_ndarray("x"), py_ndarray("y"))[py_values("idx")], py_value("np.multiply.outer(x, y)[idx]"))

0 commit comments

Comments
 (0)