Skip to content

Commit b5b2d49

Browse files
Fix lambda if
1 parent 3f25d52 commit b5b2d49

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

python/egg_smol/examples/lambda.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,24 @@ def freer(t: Term) -> StringSet:
160160
),
161161
)
162162

163+
result = egraph.relation("result")
164+
165+
166+
def l(fn: Callable[[Term], Term]) -> Term: # noqa
167+
"""
168+
Create a lambda term from a function
169+
"""
170+
# Use first var name from fn
171+
x = fn.__code__.co_varnames[0]
172+
return lam(Var(x), fn(Term.var(Var(x))))
173+
163174

164175
def assert_simplifies(left: BaseExpr, right: BaseExpr) -> None:
165176
"""
166177
Simplify and print
167178
"""
168-
res = egraph.simplify(left, 30)
179+
with egraph:
180+
res = egraph.simplify(left, 30)
169181
print(f"{left}{res}")
170182
assert expr_parts(res) == expr_parts(right), f"{res} != {right}"
171183

@@ -174,18 +186,6 @@ def assert_simplifies(left: BaseExpr, right: BaseExpr) -> None:
174186
assert_simplifies((Term.val(Val(1)) + Term.val(Val(2))).eval(), Val(3))
175187

176188

177-
result = egraph.relation("result")
178-
179-
180-
def l(fn: Callable[[Term], Term]) -> Term: # noqa
181-
"""
182-
Create a lambda term from a function
183-
"""
184-
# Use first var name from fn
185-
x = fn.__code__.co_varnames[0]
186-
return lam(Var(x), fn(Term.var(Var(x))))
187-
188-
189189
# lambda under
190190
assert_simplifies(
191191
l(lambda x: Term.val(Val(4)) + l(lambda y: y)(Term.val(Val(4)))),
@@ -270,7 +270,7 @@ def l(fn: Callable[[Term], Term]) -> Term: # noqa
270270
assert_simplifies(if_(Term.val(Val(1)) == Term.val(Val(1)), Term.val(Val(7)), Term.val(Val(9))), Term.val(Val(7)))
271271

272272

273-
# lambda_compose_many
273+
# # lambda_compose_many
274274
assert_simplifies(
275275
let_(
276276
compose,
@@ -298,7 +298,7 @@ def l(fn: Callable[[Term], Term]) -> Term: # noqa
298298
let_(
299299
zeroone,
300300
l(lambda x: if_(x == Term.val(Val(0)), Term.val(Val(0)), Term.val(Val(1)))),
301-
Term.val(Val(0)) + Term.var(zeroone)(Term.val(Val(0))) + Term.var(zeroone)(Term.val(Val(10))),
301+
Term.var(zeroone)(Term.val(Val(0))) + Term.var(zeroone)(Term.val(Val(10))),
302302
),
303303
Term.val(Val(1)),
304304
)

0 commit comments

Comments
 (0)