Skip to content

Commit 2a49fee

Browse files
Merge pull request #15 from metadsl/lambda-example
Add lambda calculus example
2 parents 5913d1e + c2e60b2 commit 2a49fee

File tree

7 files changed

+364
-11
lines changed

7 files changed

+364
-11
lines changed

docs/changelog.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
## Unreleased
44

5+
- Fix bug calling methods on paramterized types (e.g. `Map[i64, i64].empty().insert(i64(0), i64(1))`)
6+
- Fix bug for Unit type (egg name is `Unit` not `unit`)
7+
- Use `@class_` decorator to force subclassing `BaseExpr`
8+
- Workaround extracting definitions until [upstream is fixed](https://github.com/mwillsey/egg-smol/pull/140)
9+
- Rename `Map.map_remove` to `Map.remove`.
10+
- Add lambda calculus example
11+
512
## 0.3.0 (2023-04.26)
613

714
- [Upgrade `egg-smol` from `08a6e8fecdb77e6ba72a1b1d9ff4aff33229912c` to `6f2633a5fa379487fb389b80fc1225866f8b8c1a`.](https://github.com/metadsl/egg-smol-python/pull/14)

python/egg_smol/builtins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def __and__(self, __t: Map[T, V]) -> Map[T, V]: # type: ignore[empty-body]
193193
...
194194

195195
@BUILTINS.method(egg_fn="map-remove")
196-
def map_remove(self, key: T) -> Map[T, V]: # type: ignore[empty-body]
196+
def remove(self, key: T) -> Map[T, V]: # type: ignore[empty-body]
197197
...
198198

199199

python/egg_smol/declarations.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ def constant_function_decl(type_ref: JustTypeRef) -> FunctionDecl:
191191
Create a function decleartion for a constant function. This is similar to how egg-smol compiles
192192
the `constant` command.
193193
"""
194-
return FunctionDecl(arg_types=(), return_type=type_ref.to_var(), cost=bindings.HIGH_COST)
194+
# Divide high cost by 10 to not overflow the cost field.
195+
return FunctionDecl(arg_types=(), return_type=type_ref.to_var(), cost=int(bindings.HIGH_COST / 10))
195196

196197

197198
# Have two different types of type refs, one that can include vars recursively and one that cannot.
@@ -218,7 +219,7 @@ def to_commands(self, decls: Declarations) -> Iterable[bindings._Command]:
218219
egg_name = decls._type_ref_to_egg_sort[self]
219220
for arg in self.args:
220221
yield from decls._register_sort(arg)
221-
arg_sorts = [cast(bindings._Expr, bindings.Var(decls._type_ref_to_egg_sort[a])) for a in self.args]
222+
arg_sorts = [cast("bindings._Expr", bindings.Var(decls._type_ref_to_egg_sort[a])) for a in self.args]
222223
yield bindings.Sort(egg_name, (self.name, arg_sorts) if arg_sorts else None)
223224

224225
def to_var(self) -> TypeRefWithVars:

python/egg_smol/egraph.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,22 @@
6767
T = TypeVar("T")
6868
TS = TypeVarTuple("TS")
6969
P = ParamSpec("P")
70-
TYPE = TypeVar("TYPE", bound=type)
70+
TYPE = TypeVar("TYPE", bound="type[BaseExpr]")
7171
CALLABLE = TypeVar("CALLABLE", bound=Callable)
7272
EXPR = TypeVar("EXPR", bound="BaseExpr")
7373

7474
# Attributes which are sometimes added to classes by the interpreter or the dataclass decorator, or by ipython.
7575
# We ignore these when inspecting the class.
7676

77-
IGNORED_ATTRIBUTES = {"__module__", "__doc__", "__dict__", "__weakref__", "__orig_bases__", "__annotations__"}
77+
IGNORED_ATTRIBUTES = {
78+
"__module__",
79+
"__doc__",
80+
"__dict__",
81+
"__weakref__",
82+
"__orig_bases__",
83+
"__annotations__",
84+
"__hash__",
85+
}
7886

7987

8088
@dataclass
@@ -137,7 +145,7 @@ def relation(self, name: str, *tps: Unpack[TS], egg_fn: Optional[str] = None) ->
137145
Defines a relation, which is the same as a function which returns unit.
138146
"""
139147
arg_types = tuple(self._resolve_type_annotation(cast(object, tp), [], None) for tp in tps)
140-
fn_decl = FunctionDecl(arg_types, TypeRefWithVars("Unit"))
148+
fn_decl = FunctionDecl(arg_types, TypeRefWithVars("unit"))
141149
commands = self._decls.register_callable(FunctionRef(name), fn_decl, egg_fn)
142150
self._run_program(commands)
143151
return cast(Callable[[Unpack[TS]], Unit], RuntimeFunction(self._decls, name))
@@ -694,7 +702,7 @@ def __eq__(self, other: NoReturn) -> NoReturn: # type: ignore[override, empty-b
694702
BUILTINS = EGraph(_for_builtins=True)
695703

696704

697-
@BUILTINS.class_(egg_sort="unit")
705+
@BUILTINS.class_(egg_sort="Unit")
698706
class Unit(BaseExpr):
699707
"""
700708
The unit type. This is also used to reprsent if a value exists, if it is resolved or not.

python/egg_smol/examples/lambda.py

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
"""
2+
Lambda Calculus
3+
===============
4+
"""
5+
# mypy: disable-error-code=empty-body
6+
from __future__ import annotations
7+
8+
from typing import Callable, ClassVar
9+
10+
from egg_smol import *
11+
12+
egraph = EGraph()
13+
14+
# TODO: Debug extracting constants
15+
16+
17+
@egraph.class_
18+
class Val(BaseExpr):
19+
"""
20+
A value is a number or a boolean.
21+
"""
22+
23+
TRUE: ClassVar[Val]
24+
FALSE: ClassVar[Val]
25+
26+
def __init__(self, v: i64Like) -> None:
27+
...
28+
29+
30+
@egraph.class_
31+
class Var(BaseExpr):
32+
def __init__(self, v: StringLike) -> None:
33+
...
34+
35+
36+
@egraph.class_
37+
class Term(BaseExpr):
38+
@classmethod
39+
def val(cls, v: Val) -> Term:
40+
...
41+
42+
@classmethod
43+
def var(cls, v: Var) -> Term:
44+
...
45+
46+
def __add__(self, other: Term) -> Term:
47+
...
48+
49+
def __eq__(self, other: Term) -> Term: # type: ignore[override]
50+
...
51+
52+
def __call__(self, other: Term) -> Term:
53+
...
54+
55+
def eval(self) -> Val:
56+
...
57+
58+
def v(self) -> Var:
59+
...
60+
61+
62+
@egraph.function
63+
def lam(x: Var, t: Term) -> Term:
64+
...
65+
66+
67+
@egraph.function
68+
def let_(x: Var, t: Term, b: Term) -> Term:
69+
...
70+
71+
72+
@egraph.function
73+
def fix(x: Var, t: Term) -> Term:
74+
...
75+
76+
77+
@egraph.function
78+
def if_(c: Term, t: Term, f: Term) -> Term:
79+
...
80+
81+
82+
StringSet = Map[Var, i64]
83+
84+
85+
@egraph.function(merge=lambda old, new: old & new)
86+
def freer(t: Term) -> StringSet:
87+
...
88+
89+
90+
(v, v1, v2) = vars_("v v1 v2", Val)
91+
(t, t1, t2, t3, t4) = vars_("t t1 t2 t3 t4", Term)
92+
(x, y) = vars_("x y", Var)
93+
fv, fv1, fv2, fv3 = vars_("fv fv1 fv2 fv3", StringSet)
94+
i1, i2 = vars_("i1 i2", i64)
95+
egraph.register(
96+
# freer
97+
rule(eq(t).to(Term.val(v))).then(set_(freer(t)).to(StringSet.empty())),
98+
rule(eq(t).to(Term.var(x))).then(set_(freer(t)).to(StringSet.empty().insert(x, i64(1)))),
99+
rule(eq(t).to(t1 + t2), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
100+
rule(eq(t).to(t1 == t2), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
101+
rule(eq(t).to(t1(t2)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
102+
rule(eq(t).to(lam(x, t1)), eq(freer(t1)).to(fv)).then(set_(freer(t)).to(fv.remove(x))),
103+
rule(eq(t).to(let_(x, t1, t2)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(
104+
set_(freer(t)).to(fv1.remove(x) | fv2)
105+
),
106+
rule(eq(t).to(fix(x, t1)), eq(freer(t1)).to(fv)).then(set_(freer(t)).to(fv.remove(x))),
107+
rule(eq(t).to(if_(t1, t2, t3)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2), eq(freer(t3)).to(fv3)).then(
108+
set_(freer(t)).to(fv1 | fv2 | fv3)
109+
),
110+
# eval
111+
rule(eq(t).to(Term.val(v))).then(set_(t.eval()).to(v)),
112+
rule(eq(t).to(t1 + t2), eq(Val(i1)).to(t1.eval()), eq(Val(i2)).to(t2.eval())).then(
113+
union(t.eval()).with_(Val(i1 + i2))
114+
),
115+
rule(eq(t).to(t1 == t2), eq(t1.eval()).to(t2.eval())).then(union(t.eval()).with_(Val.TRUE)),
116+
rule(eq(t).to(t1 == t2), eq(t1.eval()).to(v1), eq(t2.eval()).to(v2), v1 != v2).then(
117+
union(t.eval()).with_(Val.FALSE)
118+
),
119+
rule(eq(v).to(t.eval())).then(union(t).with_(Term.val(v))),
120+
# if
121+
rewrite(if_(Term.val(Val.TRUE), t1, t2)).to(t1),
122+
rewrite(if_(Term.val(Val.FALSE), t1, t2)).to(t2),
123+
# if-elim
124+
# Adds let rules so next one can match on them
125+
rule(eq(t).to(if_(Term.var(x) == t1, t2, t3))).then(let_(x, t1, t2), let_(x, t1, t3)),
126+
rewrite(if_(Term.var(x) == t1, t2, t3)).to(
127+
t3,
128+
eq(let_(x, t1, t2)).to(let_(x, t1, t3)),
129+
),
130+
# add-comm
131+
rewrite(t1 + t2).to(t2 + t1),
132+
# add-assoc
133+
rewrite((t1 + t2) + t3).to(t1 + (t2 + t3)),
134+
# eq-comm
135+
rewrite(t1 == t2).to(t2 == t1),
136+
# Fix
137+
rewrite(fix(x, t)).to(let_(x, fix(x, t), t)),
138+
# beta reduction
139+
rewrite(lam(x, t)(t1)).to(let_(x, t1, t)),
140+
# let-app
141+
rewrite(let_(x, t, t1(t2))).to(let_(x, t, t1)(let_(x, t, t2))),
142+
# let-add
143+
rewrite(let_(x, t, t1 + t2)).to(let_(x, t, t1) + let_(x, t, t2)),
144+
# let-eq
145+
rewrite(let_(x, t, t1 == t2)).to(let_(x, t, t1) == let_(x, t, t2)),
146+
# let-const
147+
rewrite(let_(x, t, Term.val(v))).to(Term.val(v)),
148+
# let-if
149+
rewrite(let_(x, t, if_(t1, t2, t3))).to(if_(let_(x, t, t1), let_(x, t, t2), let_(x, t, t3))),
150+
# let-var-same
151+
rewrite(let_(x, t, Term.var(x))).to(t),
152+
# let-var-diff
153+
rewrite(let_(x, t, Term.var(y))).to(Term.var(y), x != y),
154+
# let-lam-same
155+
rewrite(let_(x, t, lam(x, t1))).to(lam(x, t1)),
156+
# let-lam-diff
157+
rewrite(let_(x, t, lam(y, t1))).to(lam(y, let_(x, t, t1)), x != y, eq(fv).to(freer(t)), fv.not_contains(y)),
158+
rule(eq(t).to(let_(x, t1, lam(y, t2))), x != y, eq(fv).to(freer(t1)), fv.contains(y)).then(
159+
union(t).with_(lam(t.v(), let_(x, t1, let_(y, Term.var(t.v()), t2))))
160+
),
161+
)
162+
163+
result = egraph.relation("result")
164+
165+
166+
def l(fn: Callable[[Term], Term]) -> Term: # noqa
167+
"""
168+
Create a lambda term from a function
169+
"""
170+
# Use first var name from fn
171+
x = fn.__code__.co_varnames[0]
172+
return lam(Var(x), fn(Term.var(Var(x))))
173+
174+
175+
def assert_simplifies(left: BaseExpr, right: BaseExpr) -> None:
176+
"""
177+
Simplify and print
178+
"""
179+
with egraph:
180+
res = egraph.simplify(left, 30)
181+
print(f"{left}{res}")
182+
assert expr_parts(res) == expr_parts(right), f"{res} != {right}"
183+
184+
185+
assert_simplifies((Term.val(Val(1))).eval(), Val(1))
186+
assert_simplifies((Term.val(Val(1)) + Term.val(Val(2))).eval(), Val(3))
187+
188+
189+
# lambda under
190+
assert_simplifies(
191+
l(lambda x: Term.val(Val(4)) + l(lambda y: y)(Term.val(Val(4)))),
192+
l(lambda x: Term.val(Val(8))),
193+
)
194+
# lambda if elim
195+
a = Term.var(Var("a"))
196+
b = Term.var(Var("b"))
197+
with egraph:
198+
e1 = egraph.define("e1", if_(a == b, a + a, a + b))
199+
egraph.run(10)
200+
egraph.check(eq(e1).to(a + b))
201+
202+
# lambda let simple
203+
x = Var("x")
204+
y = Var("y")
205+
assert_simplifies(
206+
let_(x, Term.val(Val(0)), let_(y, Term.val(Val(1)), Term.var(x) + Term.var(y))),
207+
Term.val(Val(1)),
208+
)
209+
# lambda capture
210+
assert_simplifies(
211+
let_(x, Term.val(Val(1)), l(lambda x: x)),
212+
l(lambda x: x),
213+
)
214+
# lambda capture free
215+
with egraph:
216+
e5 = egraph.define("e5", let_(y, Term.var(x) + Term.var(x), l(lambda x: Term.var(y))))
217+
egraph.run(10)
218+
egraph.check(freer(l(lambda x: Term.var(y))).contains(y))
219+
egraph.check_fail(eq(e5).to(l(lambda x: x + x)))
220+
221+
# lambda_closure_not_seven
222+
with egraph:
223+
e6 = egraph.define(
224+
"e6",
225+
let_(
226+
Var("five"),
227+
Term.val(Val(5)),
228+
let_(
229+
Var("add-five"),
230+
l(lambda x: x + Term.var(Var("five"))),
231+
let_(Var("five"), Term.val(Val(6)), Term.var(Var("add-five"))(Term.val(Val(1)))),
232+
),
233+
),
234+
)
235+
egraph.run(10)
236+
egraph.check_fail(eq(e6).to(Term.val(Val(7))))
237+
egraph.check(eq(e6).to(Term.val(Val(6))))
238+
239+
240+
# lambda_compose
241+
with egraph:
242+
compose = Var("compose")
243+
add1 = Var("add1")
244+
e7 = egraph.define(
245+
"e7",
246+
let_(
247+
compose,
248+
l(
249+
lambda f: l(
250+
lambda g: l(
251+
lambda x: f(g(x)),
252+
),
253+
),
254+
),
255+
let_(
256+
add1,
257+
l(lambda y: y + Term.val(Val(1))),
258+
Term.var(compose)(Term.var(add1))(Term.var(add1)),
259+
),
260+
),
261+
)
262+
egraph.run(20)
263+
egraph.register(
264+
rule(
265+
eq(t1).to(l(lambda x: Term.val(Val(1)) + l(lambda y: Term.val(Val(1)) + y)(x))),
266+
eq(t2).to(l(lambda x: x + Term.val(Val(2)))),
267+
).then(result())
268+
)
269+
egraph.run(1)
270+
egraph.check(result())
271+
272+
273+
# lambda_if_simple
274+
assert_simplifies(if_(Term.val(Val(1)) == Term.val(Val(1)), Term.val(Val(7)), Term.val(Val(9))), Term.val(Val(7)))
275+
276+
277+
# # lambda_compose_many
278+
assert_simplifies(
279+
let_(
280+
compose,
281+
l(lambda f: l(lambda g: l(lambda x: f(g(x))))),
282+
let_(
283+
add1,
284+
l(lambda y: y + Term.val(Val(1))),
285+
Term.var(compose)(Term.var(add1))(
286+
Term.var(compose)(Term.var(add1))(
287+
Term.var(compose)(Term.var(add1))(
288+
Term.var(compose)(Term.var(add1))(
289+
Term.var(compose)(Term.var(add1))(Term.var(compose)(Term.var(add1))(Term.var(add1)))
290+
)
291+
)
292+
)
293+
),
294+
),
295+
),
296+
l(lambda x: x + Term.val(Val(7))),
297+
)
298+
299+
# lambda_if
300+
zeroone = Var("zeroone")
301+
assert_simplifies(
302+
let_(
303+
zeroone,
304+
l(lambda x: if_(x == Term.val(Val(0)), Term.val(Val(0)), Term.val(Val(1)))),
305+
Term.var(zeroone)(Term.val(Val(0))) + Term.var(zeroone)(Term.val(Val(10))),
306+
),
307+
Term.val(Val(1)),
308+
)

0 commit comments

Comments
 (0)