28
28
"FunctionCallableRef" ,
29
29
"CallableRef" ,
30
30
"ConstantRef" ,
31
- # "constant_function_decl",
32
31
"FunctionDecl" ,
33
32
"VarDecl" ,
34
33
"LitType" ,
45
44
"BiRewrite" ,
46
45
"Eq" ,
47
46
"ExprFact" ,
48
- # "fact_decl_to_egg",
49
47
"Rule" ,
50
48
"Let" ,
51
49
"Set" ,
56
54
"Schedule" ,
57
55
"Sequence" ,
58
56
"Run" ,
59
- # "action_decl_to_egg",
60
57
]
61
58
# Special methods which we might want to use as functions
62
59
# Mapping to the operator they represent for pretty printing them
@@ -240,27 +237,30 @@ def register_class(self, name: str, n_type_vars: int, egg_sort: Optional[str]) -
240
237
raise ValueError (f"Class { name } already registered" )
241
238
decl = ClassDecl (n_type_vars = n_type_vars )
242
239
self ._decl ._classes [name ] = decl
243
- return self .register_sort (JustTypeRef (name ), egg_sort )
240
+ _egg_sort , cmds = self .register_sort (JustTypeRef (name ), egg_sort )
241
+ return cmds
244
242
245
- def register_sort (self , ref : JustTypeRef , egg_name : Optional [str ] = None ) -> Iterable [bindings ._Command ]:
243
+ def register_sort (
244
+ self , ref : JustTypeRef , egg_name : Optional [str ] = None
245
+ ) -> tuple [str , Iterable [bindings ._Command ]]:
246
246
"""
247
247
Register a sort with the given name. If no name is given, one is generated.
248
248
249
249
If this is a type called with generic args, register the generic args as well.
250
250
"""
251
251
# If the sort is already registered, do nothing
252
252
try :
253
- self .get_egg_sort (ref )
253
+ egg_sort = self .get_egg_sort (ref )
254
254
except KeyError :
255
255
pass
256
256
else :
257
- return []
257
+ return ( egg_sort , [])
258
258
egg_name = egg_name or ref .generate_egg_name ()
259
259
if egg_name in self ._decl ._egg_sort_to_type_ref :
260
260
raise ValueError (f"Sort { egg_name } is already registered." )
261
261
self ._decl ._egg_sort_to_type_ref [egg_name ] = ref
262
262
self ._decl ._type_ref_to_egg_sort [ref ] = egg_name
263
- return ref .to_commands (self )
263
+ return egg_name , ref .to_commands (self )
264
264
265
265
def register_function_callable (
266
266
self ,
@@ -284,20 +284,12 @@ def register_function_callable(
284
284
def register_constant_callable (
285
285
self , ref : ConstantCallableRef , type_ref : JustTypeRef , egg_name : Optional [str ]
286
286
) -> Iterable [bindings ._Command ]:
287
- egg_name = egg_name or ref .generate_egg_name ()
288
- self ._decl .register_callable_ref (ref , ref . generate_egg_name () )
287
+ egg_function = ref .generate_egg_name ()
288
+ self ._decl .register_callable_ref (ref , egg_function )
289
289
self ._decl .set_constant_type (ref , type_ref )
290
- yield from type_ref .to_commands (self )
291
- yield bindings .Declare (egg_name , egg_name )
292
-
293
-
294
- # def constant_function_decl(type_ref: JustTypeRef) -> FunctionDecl:
295
- # """
296
- # Create a function decleartion for a constant function. This is similar to how egglog compiles
297
- # the `constant` command.
298
- # """
299
- # # Divide high cost by 10 to not overflow the cost field.
300
- # return FunctionDecl(arg_types=(), return_type=type_ref.to_var(), cost=int(bindings.HIGH_COST / 10))
290
+ # Create a function decleartion for a constant function. This is similar to how egglog compiles
291
+ # the `declare` command.
292
+ return FunctionDecl ((), type_ref .to_var ()).to_commands (self , egg_name or ref .generate_egg_name ())
301
293
302
294
303
295
# Have two different types of type refs, one that can include vars recursively and one that cannot.
@@ -322,9 +314,11 @@ def to_commands(self, mod_decls: ModuleDeclarations) -> Iterable[bindings._Comma
322
314
Returns commands to register this as a sort, as well as for any of its arguments.
323
315
"""
324
316
egg_name = mod_decls .get_egg_sort (self )
317
+ arg_sorts : list [bindings ._Expr ] = []
325
318
for arg in self .args :
326
- yield from mod_decls .register_sort (arg )
327
- arg_sorts : list [bindings ._Expr ] = [bindings .Var (mod_decls .get_egg_sort (a )) for a in self .args ]
319
+ egg_sort , cmds = mod_decls .register_sort (arg )
320
+ arg_sorts .append (bindings .Var (egg_sort ))
321
+ yield from cmds
328
322
yield bindings .Sort (egg_name , (self .name , arg_sorts ) if arg_sorts else None )
329
323
330
324
def to_var (self ) -> TypeRefWithVars :
@@ -426,36 +420,31 @@ class FunctionDecl:
426
420
arg_types : tuple [TypeOrVarRef , ...]
427
421
return_type : TypeOrVarRef
428
422
var_arg_type : Optional [TypeOrVarRef ] = None
429
- # cost: Optional[int] = None
430
- # default: Optional[ExprDecl] = None
431
- # merge: Optional[ExprDecl] = None
432
- # merge_action: tuple[ActionDecl, ...] = ()
433
423
434
424
def to_commands (
435
425
self ,
436
426
mod_decls : ModuleDeclarations ,
437
427
egg_name : str ,
438
- cost : Optional [int ],
439
- default : Optional [ExprDecl ],
440
- merge : Optional [ExprDecl ],
441
- merge_action : Iterable [Action ],
428
+ cost : Optional [int ] = None ,
429
+ default : Optional [ExprDecl ] = None ,
430
+ merge : Optional [ExprDecl ] = None ,
431
+ merge_action : Iterable [Action ] = () ,
442
432
) -> Iterable [bindings ._Command ]:
443
433
if self .var_arg_type is not None :
444
434
raise NotImplementedError ("egglog does not support variable arguments yet." )
445
- just_arg_types = [a .to_just () for a in self .arg_types ]
446
- for a in just_arg_types :
447
- yield from mod_decls .register_sort (a )
448
- just_return_type = self .return_type .to_just ()
449
- yield from mod_decls .register_sort (just_return_type )
435
+ arg_sorts : list [str ] = []
436
+ for a in self .arg_types :
437
+ # Remove all vars from the type refs, raising an errory if we find one,
438
+ # since we cannot create egg functions with vars
439
+ arg_sort , cmds = mod_decls .register_sort (a .to_just ())
440
+ yield from cmds
441
+ arg_sorts .append (arg_sort )
442
+ return_sort , cmds = mod_decls .register_sort (self .return_type .to_just ())
443
+ yield from cmds
450
444
451
445
egg_fn_decl = bindings .FunctionDecl (
452
446
egg_name ,
453
- # Remove all vars from the type refs, raising an errory if we find one,
454
- # since we cannot create egg functions with vars
455
- bindings .Schema (
456
- [mod_decls .get_egg_sort (a ) for a in just_arg_types ],
457
- mod_decls .get_egg_sort (just_return_type ),
458
- ),
447
+ bindings .Schema (arg_sorts , return_sort ),
459
448
default .to_egg (mod_decls ) if default else None ,
460
449
merge .to_egg (mod_decls ) if merge else None ,
461
450
[a ._to_egg_action (mod_decls ) for a in merge_action ],
0 commit comments