Skip to content

Commit dff1bb7

Browse files
Add support for high level
1 parent ddcc5c5 commit dff1bb7

17 files changed

+1431
-1220
lines changed

docs/explanation/compared_to_rust.md

Lines changed: 30 additions & 222 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ egg text version of this from the tests is:
4444
(check (= expr1 expr2))
4545
```
4646

47-
## Text API
47+
## Low Level API
4848

4949
One way to run this in Python is to parse the text and run it similar to how the
50-
egg CLI works:
50+
egg-smol CLI works:
5151

5252
```{code-cell} python
5353
from egg_smol.bindings import *
@@ -64,8 +64,6 @@ eqsat_basic = """(datatype Math
6464
(define expr2 (Add (Num 6) (Mul (Num 2) (Var "x"))))
6565
6666
67-
;; (rule ((= __root (Add a b)))
68-
;; ((union __root (Add b a)))
6967
(rewrite (Add a b)
7068
(Add b a))
7169
(rewrite (Mul a (Add b c))
@@ -79,236 +77,28 @@ eqsat_basic = """(datatype Math
7977
(check (= expr1 expr2))"""
8078
8179
egraph = EGraph()
82-
egraph.parse_and_run_program(eqsat_basic)
80+
commands = egraph.parse_program(eqsat_basic)
81+
egraph.run_program(*commands)
8382
```
8483

85-
## Low level bindings API
86-
87-
However, this isn't the most friendly for Python users. Instead, we can use the
88-
low level APIs that mirror the rust APIs to build the same egraph:
84+
The commands are a representation which is close the AST of the egg-smol text language. We
85+
can see this by printing the commands:
8986

9087
```{code-cell} python
91-
egraph = EGraph()
92-
egraph.declare_sort("Math")
93-
egraph.declare_constructor(Variant("Num", ["i64"]), "Math")
94-
egraph.declare_constructor(Variant("Var", ["String"]), "Math")
95-
egraph.declare_constructor(Variant("Add", ["Math", "Math"]), "Math")
96-
egraph.declare_constructor(Variant("Mul", ["Math", "Math"]), "Math")
97-
98-
# (define expr1 (Mul (Num 2) (Add (Var "x") (Num 3))))
99-
egraph.define(
100-
"expr1",
101-
Call(
102-
"Mul",
103-
[
104-
Call(
105-
"Num",
106-
[
107-
Lit(Int(2)),
108-
],
109-
),
110-
Call(
111-
"Add",
112-
[
113-
Call(
114-
"Var",
115-
[
116-
Lit(String("x")),
117-
],
118-
),
119-
Call(
120-
"Num",
121-
[
122-
Lit(Int(3)),
123-
],
124-
),
125-
],
126-
),
127-
],
128-
),
129-
)
130-
# (define expr2 (Add (Num 6) (Mul (Num 2) (Var "x"))))
131-
egraph.define(
132-
"expr2",
133-
Call(
134-
"Add",
135-
[
136-
Call(
137-
"Num",
138-
[
139-
Lit(Int(6)),
140-
],
141-
),
142-
Call(
143-
"Mul",
144-
[
145-
Call(
146-
"Num",
147-
[
148-
Lit(Int(2)),
149-
],
150-
),
151-
Call(
152-
"Var",
153-
[
154-
Lit(String("x")),
155-
],
156-
),
157-
],
158-
),
159-
],
160-
),
161-
)
162-
# (rewrite (Add a b)
163-
# (Add b a))
164-
egraph.add_rewrite(
165-
Rewrite(
166-
Call(
167-
"Add",
168-
[
169-
Var("a"),
170-
Var("b"),
171-
],
172-
),
173-
Call(
174-
"Add",
175-
[
176-
Var("b"),
177-
Var("a"),
178-
],
179-
),
180-
)
181-
)
182-
# (rewrite (Mul a (Add b c))
183-
# (Add (Mul a b) (Mul a c)))
184-
egraph.add_rewrite(
185-
Rewrite(
186-
Call(
187-
"Mul",
188-
[
189-
Var("a"),
190-
Call(
191-
"Add",
192-
[
193-
Var("b"),
194-
Var("c"),
195-
],
196-
),
197-
],
198-
),
199-
Call(
200-
"Add",
201-
[
202-
Call(
203-
"Mul",
204-
[
205-
Var("a"),
206-
Var("b"),
207-
],
208-
),
209-
Call(
210-
"Mul",
211-
[
212-
Var("a"),
213-
Var("c"),
214-
],
215-
),
216-
],
217-
),
218-
)
219-
)
220-
221-
# (rewrite (Add (Num a) (Num b))
222-
# (Num (+ a b)))
223-
lhs = Call(
224-
"Add",
225-
[
226-
Call(
227-
"Num",
228-
[
229-
Var("a"),
230-
],
231-
),
232-
Call(
233-
"Num",
234-
[
235-
Var("b"),
236-
],
237-
),
238-
],
239-
)
240-
rhs = Call(
241-
"Num",
242-
[
243-
Call(
244-
"+",
245-
[
246-
Var("a"),
247-
Var("b"),
248-
],
249-
)
250-
],
251-
)
252-
egraph.add_rewrite(Rewrite(lhs, rhs))
253-
254-
# (rewrite (Mul (Num a) (Num b))
255-
# (Num (* a b)))
256-
lhs = Call(
257-
"Mul",
258-
[
259-
Call(
260-
"Num",
261-
[
262-
Var("a"),
263-
],
264-
),
265-
Call(
266-
"Num",
267-
[
268-
Var("b"),
269-
],
270-
),
271-
],
272-
)
273-
rhs = Call(
274-
"Num",
275-
[
276-
Call(
277-
"*",
278-
[
279-
Var("a"),
280-
Var("b"),
281-
],
282-
)
283-
],
284-
)
285-
egraph.add_rewrite(Rewrite(lhs, rhs))
286-
287-
egraph.run_rules(10)
288-
egraph.check_fact(
289-
Eq(
290-
[
291-
Var("expr1"),
292-
Var("expr2"),
293-
]
294-
)
295-
)
88+
for command in commands:
89+
print(command)
29690
```
29791

298-
This has a couple of advantages over the text version. Users now know what types
299-
of functions are available to them and also it can be statically type checked with MyPy,
300-
to make sure that the types are correct.
301-
302-
However, it is much more verbose than the text version!
303-
30492
## High level API
30593

306-
So would it be possible to make an API that:
94+
The high level API builds on this API and is designed to:
30795

30896
1. Statically type checks as much as possible with MyPy
309-
2. Is concise to write
97+
2. Be concise to write
31098
3. Feels "pythonic"
31199

100+
Here is the same example using the high level API:
101+
312102
```{code-cell} python
313103
from __future__ import annotations
314104
@@ -351,3 +141,21 @@ egraph.run(10)
351141
352142
egraph.check(eq(expr1).to(expr2))
353143
```
144+
145+
### Mapping of low level to high level
146+
147+
Here are a number of the low level commands, with how they map to the high levle API:
148+
149+
- `(datatype Math ...)` -> `@egraph.class_` on a Python class. Internally, each method and classmethod are registered as functions, not as `Variant`s of the datatype, but the end result is the same.
150+
- `(set-option enable_proofs 1)` -> Not supported
151+
- `(declare True Bool)` -> As a class variable `True: Bool` or as a constant `True_ = egraph.constant("True", Bool)`. Internally, we don't actually use the `Constant` command but instead map constants to nullary functions which are immediately evaluated. This is how the `Constant` command is desugared in egg-smol anyways.
152+
- `(define expr1 ...)` -> `expr1 = egraph.define("expr1", ...)` or just `expr1 = ...` if it doesn't need to be added to the e-graph.
153+
- `(sort MyMap (Map i64 String))` -> `MyMap = Map[i64, String]`. We can use the normal Python generic typing syntax. Internally, when this is used in a type definition, we would create a new sort with the name `Map__i64__String`.
154+
- `(function f ...)` -> `@egraph.function` on a Python function with no body.
155+
- `(ruleset x)` -> `x = egraph.ruleset("x")`
156+
- `(rule (f1 f2) (a1 a2))` -> `egraph.register(rule(f1, f2).then(a1, a2))`
157+
- `(run 10 :until f :rulset x)` -> `egraph.run(10, until=f, ruleset=x)`
158+
159+
Facts
160+
161+
Actions

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ check_untyped_defs = true
5757
strict_equality = true
5858
warn_unused_configs = true
5959
allow_redefinition = true
60+
enable_incomplete_feature = ["Unpack", "TypeVarTuple"]
6061

6162
[tool.maturin]
6263
python-source = "python"

python/egg_smol/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from . import config # noqa: F401
1+
from . import config as configuration # noqa: F401
22
from .builtins import * # noqa: F401
33
from .egraph import * # noqa: F401
4-
from .registry import * # noqa: F401

python/egg_smol/bindings.pyi

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ from typing import Optional
44

55
from typing_extensions import final
66

7+
HIGH_COST: int
8+
79
@final
810
class EGraph:
911
def __init__(self, fact_directory: str | Path | None = None, seminaive=True) -> None: ...
10-
def parse_program(self, input: str) -> list[_Command]: ...
12+
def parse_program(self, __input: str, /) -> list[_Command]: ...
1113
def run_program(self, *commands: _Command) -> list[str]: ...
1214
def take_extract_report(self) -> Optional[ExtractReport]: ...
1315
def take_run_report(self) -> Optional[RunReport]: ...
@@ -282,12 +284,14 @@ class RuleCommand:
282284

283285
@final
284286
class RewriteCommand:
287+
# TODO: Rename to ruleset
285288
name: str
286289
rewrite: Rewrite
287290
def __init__(self, name: str, rewrite: Rewrite) -> None: ...
288291

289292
@final
290293
class BiRewriteCommand:
294+
# TODO: Rename to ruleset
291295
name: str
292296
rewrite: Rewrite
293297
def __init__(self, name: str, rewrite: Rewrite) -> None: ...
@@ -380,6 +384,7 @@ _Command = (
380384
| Sort
381385
| Function
382386
| Define
387+
| AddRuleset
383388
| RuleCommand
384389
| RewriteCommand
385390
| BiRewriteCommand

0 commit comments

Comments
 (0)