Skip to content

Commit 8d4ae61

Browse files
Fix tests
1 parent f8e07ef commit 8d4ae61

File tree

3 files changed

+35
-46
lines changed

3 files changed

+35
-46
lines changed

python/egglog/declarations.py

Lines changed: 31 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"FunctionCallableRef",
2929
"CallableRef",
3030
"ConstantRef",
31-
# "constant_function_decl",
3231
"FunctionDecl",
3332
"VarDecl",
3433
"LitType",
@@ -45,7 +44,6 @@
4544
"BiRewrite",
4645
"Eq",
4746
"ExprFact",
48-
# "fact_decl_to_egg",
4947
"Rule",
5048
"Let",
5149
"Set",
@@ -56,7 +54,6 @@
5654
"Schedule",
5755
"Sequence",
5856
"Run",
59-
# "action_decl_to_egg",
6057
]
6158
# Special methods which we might want to use as functions
6259
# Mapping to the operator they represent for pretty printing them
@@ -240,27 +237,30 @@ def register_class(self, name: str, n_type_vars: int, egg_sort: Optional[str]) -
240237
raise ValueError(f"Class {name} already registered")
241238
decl = ClassDecl(n_type_vars=n_type_vars)
242239
self._decl._classes[name] = decl
243-
return self.register_sort(JustTypeRef(name), egg_sort)
240+
_egg_sort, cmds = self.register_sort(JustTypeRef(name), egg_sort)
241+
return cmds
244242

245-
def register_sort(self, ref: JustTypeRef, egg_name: Optional[str] = None) -> Iterable[bindings._Command]:
243+
def register_sort(
244+
self, ref: JustTypeRef, egg_name: Optional[str] = None
245+
) -> tuple[str, Iterable[bindings._Command]]:
246246
"""
247247
Register a sort with the given name. If no name is given, one is generated.
248248
249249
If this is a type called with generic args, register the generic args as well.
250250
"""
251251
# If the sort is already registered, do nothing
252252
try:
253-
self.get_egg_sort(ref)
253+
egg_sort = self.get_egg_sort(ref)
254254
except KeyError:
255255
pass
256256
else:
257-
return []
257+
return (egg_sort, [])
258258
egg_name = egg_name or ref.generate_egg_name()
259259
if egg_name in self._decl._egg_sort_to_type_ref:
260260
raise ValueError(f"Sort {egg_name} is already registered.")
261261
self._decl._egg_sort_to_type_ref[egg_name] = ref
262262
self._decl._type_ref_to_egg_sort[ref] = egg_name
263-
return ref.to_commands(self)
263+
return egg_name, ref.to_commands(self)
264264

265265
def register_function_callable(
266266
self,
@@ -284,20 +284,12 @@ def register_function_callable(
284284
def register_constant_callable(
285285
self, ref: ConstantCallableRef, type_ref: JustTypeRef, egg_name: Optional[str]
286286
) -> Iterable[bindings._Command]:
287-
egg_name = egg_name or ref.generate_egg_name()
288-
self._decl.register_callable_ref(ref, ref.generate_egg_name())
287+
egg_function = ref.generate_egg_name()
288+
self._decl.register_callable_ref(ref, egg_function)
289289
self._decl.set_constant_type(ref, type_ref)
290-
yield from type_ref.to_commands(self)
291-
yield bindings.Declare(egg_name, egg_name)
292-
293-
294-
# def constant_function_decl(type_ref: JustTypeRef) -> FunctionDecl:
295-
# """
296-
# Create a function decleartion for a constant function. This is similar to how egglog compiles
297-
# the `constant` command.
298-
# """
299-
# # Divide high cost by 10 to not overflow the cost field.
300-
# return FunctionDecl(arg_types=(), return_type=type_ref.to_var(), cost=int(bindings.HIGH_COST / 10))
290+
# Create a function decleartion for a constant function. This is similar to how egglog compiles
291+
# the `declare` command.
292+
return FunctionDecl((), type_ref.to_var()).to_commands(self, egg_name or ref.generate_egg_name())
301293

302294

303295
# Have two different types of type refs, one that can include vars recursively and one that cannot.
@@ -322,9 +314,11 @@ def to_commands(self, mod_decls: ModuleDeclarations) -> Iterable[bindings._Comma
322314
Returns commands to register this as a sort, as well as for any of its arguments.
323315
"""
324316
egg_name = mod_decls.get_egg_sort(self)
317+
arg_sorts: list[bindings._Expr] = []
325318
for arg in self.args:
326-
yield from mod_decls.register_sort(arg)
327-
arg_sorts: list[bindings._Expr] = [bindings.Var(mod_decls.get_egg_sort(a)) for a in self.args]
319+
egg_sort, cmds = mod_decls.register_sort(arg)
320+
arg_sorts.append(bindings.Var(egg_sort))
321+
yield from cmds
328322
yield bindings.Sort(egg_name, (self.name, arg_sorts) if arg_sorts else None)
329323

330324
def to_var(self) -> TypeRefWithVars:
@@ -426,36 +420,31 @@ class FunctionDecl:
426420
arg_types: tuple[TypeOrVarRef, ...]
427421
return_type: TypeOrVarRef
428422
var_arg_type: Optional[TypeOrVarRef] = None
429-
# cost: Optional[int] = None
430-
# default: Optional[ExprDecl] = None
431-
# merge: Optional[ExprDecl] = None
432-
# merge_action: tuple[ActionDecl, ...] = ()
433423

434424
def to_commands(
435425
self,
436426
mod_decls: ModuleDeclarations,
437427
egg_name: str,
438-
cost: Optional[int],
439-
default: Optional[ExprDecl],
440-
merge: Optional[ExprDecl],
441-
merge_action: Iterable[Action],
428+
cost: Optional[int] = None,
429+
default: Optional[ExprDecl] = None,
430+
merge: Optional[ExprDecl] = None,
431+
merge_action: Iterable[Action] = (),
442432
) -> Iterable[bindings._Command]:
443433
if self.var_arg_type is not None:
444434
raise NotImplementedError("egglog does not support variable arguments yet.")
445-
just_arg_types = [a.to_just() for a in self.arg_types]
446-
for a in just_arg_types:
447-
yield from mod_decls.register_sort(a)
448-
just_return_type = self.return_type.to_just()
449-
yield from mod_decls.register_sort(just_return_type)
435+
arg_sorts: list[str] = []
436+
for a in self.arg_types:
437+
# Remove all vars from the type refs, raising an errory if we find one,
438+
# since we cannot create egg functions with vars
439+
arg_sort, cmds = mod_decls.register_sort(a.to_just())
440+
yield from cmds
441+
arg_sorts.append(arg_sort)
442+
return_sort, cmds = mod_decls.register_sort(self.return_type.to_just())
443+
yield from cmds
450444

451445
egg_fn_decl = bindings.FunctionDecl(
452446
egg_name,
453-
# Remove all vars from the type refs, raising an errory if we find one,
454-
# since we cannot create egg functions with vars
455-
bindings.Schema(
456-
[mod_decls.get_egg_sort(a) for a in just_arg_types],
457-
mod_decls.get_egg_sort(just_return_type),
458-
),
447+
bindings.Schema(arg_sorts, return_sort),
459448
default.to_egg(mod_decls) if default else None,
460449
merge.to_egg(mod_decls) if merge else None,
461450
[a._to_egg_action(mod_decls) for a in merge_action],

python/egglog/egraph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,6 @@ def _register_constant(
523523
Register a constant, returning its typeref().
524524
"""
525525
type_ref = self._resolve_type_annotation(tp, [], cls_type_and_name).to_just()
526-
# fn_decl = constant_function_decl(type_ref)
527526
self._process_commands(self._mod_decls.register_constant_callable(ref, type_ref, egg_name))
528527
return type_ref
529528

python/tests/test_high_level.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ def __init__(self, v: i64) -> None:
161161
assert expr_parts(egraph.simplify(Numeric.ONE, 10)) == expr_parts(Numeric.ONE)
162162

163163
egraph.register(union(Numeric.ONE).with_(Numeric(i64(1))))
164-
165-
assert expr_parts(egraph.simplify(Numeric.ONE, 10)) == expr_parts(Numeric(i64(1)))
164+
egraph.run(10)
165+
egraph.check(eq(Numeric.ONE).to(Numeric(i64(1))))
166166

167167

168168
def test_extract_constant_twice():
@@ -209,12 +209,13 @@ class Numeric(BaseExpr):
209209

210210
@m2.class_
211211
class OtherNumeric(BaseExpr):
212+
@m2.method(cost=10)
212213
def __init__(self, v: i64Like) -> None:
213214
...
214215

215216
egraph = EGraph(deps=[m, m2])
216217

217-
@egraph.function(cost=0)
218+
@egraph.function
218219
def from_numeric(n: Numeric) -> OtherNumeric: # type: ignore[empty-body]
219220
...
220221

0 commit comments

Comments
 (0)