@@ -87,9 +87,7 @@ class Declarations:
87
87
# Bidirectional mapping between egg function names and python callable references.
88
88
# Note that there are possibly mutliple callable references for a single egg function name, like `+`
89
89
# for both int and rational classes.
90
- egg_fn_to_callable_refs : defaultdict [str , set [CallableRef ]] = field (
91
- default_factory = lambda : defaultdict (set )
92
- )
90
+ egg_fn_to_callable_refs : defaultdict [str , set [CallableRef ]] = field (default_factory = lambda : defaultdict (set ))
93
91
callable_ref_to_egg_fn : dict [CallableRef , str ] = field (default_factory = dict )
94
92
95
93
# Bidirectional mapping between egg sort names and python type references.
@@ -119,9 +117,7 @@ def get_function_decl(self, ref: CallableRef) -> FunctionDecl:
119
117
return self .constants [ref .name ].to_function_decl ()
120
118
assert_never (ref )
121
119
122
- def register_sort (
123
- self , type_ref : JustTypeRef , egg_name : Optional [str ] = None
124
- ) -> str :
120
+ def register_sort (self , type_ref : JustTypeRef , egg_name : Optional [str ] = None ) -> str :
125
121
egg_name = egg_name or type_ref .generate_egg_name ()
126
122
if egg_name in self .egg_sort_to_type_ref :
127
123
raise ValueError (f"Sort { egg_name } is already registered." )
@@ -154,10 +150,7 @@ def to_egg(self, decls: Declarations, egraph: bindings.EGraph) -> str:
154
150
raise ValueError (f"Type { self .name } is not registered." )
155
151
# If this is a type with arguments and it is not registered, then we need to register it
156
152
egg_name = decls .register_sort (self )
157
- arg_sorts = [
158
- cast (bindings ._Expr , bindings .Var (a .to_egg (decls , egraph )))
159
- for a in self .args
160
- ]
153
+ arg_sorts = [cast (bindings ._Expr , bindings .Var (a .to_egg (decls , egraph ))) for a in self .args ]
161
154
egraph .declare_sort (egg_name , (self .name , arg_sorts ))
162
155
return egg_name
163
156
@@ -265,9 +258,7 @@ class FunctionDecl:
265
258
default : Optional [ExprDecl ] = None
266
259
merge : Optional [ExprDecl ] = None
267
260
268
- def to_egg (
269
- self , decls : Declarations , egraph : bindings .EGraph , ref : CallableRef
270
- ) -> bindings .FunctionDecl :
261
+ def to_egg (self , decls : Declarations , egraph : bindings .EGraph , ref : CallableRef ) -> bindings .FunctionDecl :
271
262
return bindings .FunctionDecl (
272
263
decls .callable_ref_to_egg_fn [ref ],
273
264
# Remove all vars from the type refs, raising an errory if we find one,
@@ -288,9 +279,7 @@ class VarDecl:
288
279
289
280
@classmethod
290
281
def from_egg (cls , var : bindings .Var ) -> tuple [JustTypeRef , LitDecl ]:
291
- raise NotImplementedError (
292
- "Cannot turn var into egg type because typing unknown."
293
- )
282
+ raise NotImplementedError ("Cannot turn var into egg type because typing unknown." )
294
283
295
284
def to_egg (self , _decls : Declarations ) -> bindings .Var :
296
285
return bindings .Var (self .name )
@@ -349,14 +338,10 @@ class CallDecl:
349
338
350
339
def __post_init__ (self ):
351
340
if self .bound_tp_params and not isinstance (self .callable , ClassMethodRef ):
352
- raise ValueError (
353
- "Cannot bind type parameters to a non-class method callable."
354
- )
341
+ raise ValueError ("Cannot bind type parameters to a non-class method callable." )
355
342
356
343
@classmethod
357
- def from_egg (
358
- cls , decls : Declarations , call : bindings .Call
359
- ) -> tuple [JustTypeRef , CallDecl ]:
344
+ def from_egg (cls , decls : Declarations , call : bindings .Call ) -> tuple [JustTypeRef , CallDecl ]:
360
345
from .type_constraint_solver import TypeConstraintSolver
361
346
362
347
results = [tp_and_expr_decl_from_egg (decls , a ) for a in call .args ]
@@ -367,9 +352,7 @@ def from_egg(
367
352
for callable_ref in decls .egg_fn_to_callable_refs [call .name ]:
368
353
tcs = TypeConstraintSolver ()
369
354
fn_decl = decls .get_function_decl (callable_ref )
370
- return_tp = tcs .infer_return_type (
371
- fn_decl .arg_types , fn_decl .return_type , arg_types
372
- )
355
+ return_tp = tcs .infer_return_type (fn_decl .arg_types , fn_decl .return_type , arg_types )
373
356
return return_tp , cls (callable_ref , arg_decls )
374
357
raise ValueError (f"Could not find callable ref for call { call } " )
375
358
@@ -421,34 +404,12 @@ def test_expr_pretty():
421
404
assert LitDecl ("foo" ).pretty () == 'String("foo")'
422
405
assert LitDecl (None ).pretty () == "unit()"
423
406
assert CallDecl (FunctionRef ("foo" ), (VarDecl ("x" ),)).pretty () == "foo(x)"
424
- assert (
425
- CallDecl (
426
- FunctionRef ("foo" ), (VarDecl ("x" ), VarDecl ("y" ), VarDecl ("z" ))
427
- ).pretty ()
428
- == "foo(x, y, z)"
429
- )
430
- assert (
431
- CallDecl (MethodRef ("foo" , "__add__" ), (VarDecl ("x" ), VarDecl ("y" ))).pretty ()
432
- == "x + y"
433
- )
434
- assert (
435
- CallDecl (MethodRef ("foo" , "__getitem__" ), (VarDecl ("x" ), VarDecl ("y" ))).pretty ()
436
- == "x[y]"
437
- )
438
- assert (
439
- CallDecl (
440
- ClassMethodRef ("foo" , "__init__" ), (VarDecl ("x" ), VarDecl ("y" ))
441
- ).pretty ()
442
- == "foo(x, y)"
443
- )
444
- assert (
445
- CallDecl (ClassMethodRef ("foo" , "bar" ), (VarDecl ("x" ), VarDecl ("y" ))).pretty ()
446
- == "foo.bar(x, y)"
447
- )
448
- assert (
449
- CallDecl (MethodRef ("foo" , "__call__" ), (VarDecl ("x" ), VarDecl ("y" ))).pretty ()
450
- == "x(y)"
451
- )
407
+ assert CallDecl (FunctionRef ("foo" ), (VarDecl ("x" ), VarDecl ("y" ), VarDecl ("z" ))).pretty () == "foo(x, y, z)"
408
+ assert CallDecl (MethodRef ("foo" , "__add__" ), (VarDecl ("x" ), VarDecl ("y" ))).pretty () == "x + y"
409
+ assert CallDecl (MethodRef ("foo" , "__getitem__" ), (VarDecl ("x" ), VarDecl ("y" ))).pretty () == "x[y]"
410
+ assert CallDecl (ClassMethodRef ("foo" , "__init__" ), (VarDecl ("x" ), VarDecl ("y" ))).pretty () == "foo(x, y)"
411
+ assert CallDecl (ClassMethodRef ("foo" , "bar" ), (VarDecl ("x" ), VarDecl ("y" ))).pretty () == "foo.bar(x, y)"
412
+ assert CallDecl (MethodRef ("foo" , "__call__" ), (VarDecl ("x" ), VarDecl ("y" ))).pretty () == "x(y)"
452
413
assert (
453
414
CallDecl (
454
415
ClassMethodRef ("Map" , "__init__" ),
@@ -462,9 +423,7 @@ def test_expr_pretty():
462
423
ExprDecl = Union [VarDecl , LitDecl , CallDecl ]
463
424
464
425
465
- def tp_and_expr_decl_from_egg (
466
- decls : Declarations , expr : bindings ._Expr
467
- ) -> tuple [JustTypeRef , ExprDecl ]:
426
+ def tp_and_expr_decl_from_egg (decls : Declarations , expr : bindings ._Expr ) -> tuple [JustTypeRef , ExprDecl ]:
468
427
if isinstance (expr , bindings .Var ):
469
428
return VarDecl .from_egg (expr )
470
429
if isinstance (expr , bindings .Lit ):
@@ -557,9 +516,7 @@ class DeleteDecl:
557
516
call : CallDecl
558
517
559
518
def to_egg (self , decls : Declarations ) -> bindings .Delete :
560
- return bindings .Delete (
561
- self .call .callable .to_egg (decls ), [a .to_egg (decls ) for a in self .call .args ]
562
- )
519
+ return bindings .Delete (self .call .callable .to_egg (decls ), [a .to_egg (decls ) for a in self .call .args ])
563
520
564
521
565
522
@dataclass (frozen = True )
0 commit comments