@@ -185,6 +185,9 @@ def register_callable_ref(self, ref: CallableRef, egg_name: str) -> None:
185
185
def get_egg_fn (self , ref : CallableRef ) -> str :
186
186
return self ._callable_ref_to_egg_fn [ref ]
187
187
188
+ def get_egg_sort (self , ref : JustTypeRef ) -> str :
189
+ return self ._type_ref_to_egg_sort [ref ]
190
+
188
191
189
192
def constant_function_decl (type_ref : JustTypeRef ) -> FunctionDecl :
190
193
"""
@@ -216,10 +219,10 @@ def to_commands(self, decls: Declarations) -> Iterable[bindings._Command]:
216
219
"""
217
220
Register this type with the egg solver.
218
221
"""
219
- egg_name = decls ._type_ref_to_egg_sort [ self ]
222
+ egg_name = decls .get_egg_sort ( self )
220
223
for arg in self .args :
221
224
yield from decls ._register_sort (arg )
222
- arg_sorts = [cast ("bindings._Expr" , bindings .Var (decls ._type_ref_to_egg_sort [ a ] )) for a in self .args ]
225
+ arg_sorts = [cast ("bindings._Expr" , bindings .Var (decls .get_egg_sort ( a ) )) for a in self .args ]
223
226
yield bindings .Sort (egg_name , (self .name , arg_sorts ) if arg_sorts else None )
224
227
225
228
def to_var (self ) -> TypeRefWithVars :
@@ -280,7 +283,7 @@ class ClassMethodRef:
280
283
method_name : str
281
284
282
285
def to_egg (self , decls : Declarations ) -> str :
283
- return decls ._callable_ref_to_egg_fn [ self ]
286
+ return decls .get_egg_fn ( self )
284
287
285
288
def generate_egg_name (self ) -> str :
286
289
return f"{ self .class_name } .{ self .method_name } "
@@ -331,8 +334,8 @@ def to_commands(self, decls: Declarations, egg_name: str) -> Iterable[bindings._
331
334
# Remove all vars from the type refs, raising an errory if we find one,
332
335
# since we cannot create egg functions with vars
333
336
bindings .Schema (
334
- [decls ._type_ref_to_egg_sort [ a ] for a in just_arg_types ],
335
- decls ._type_ref_to_egg_sort [ just_return_type ] ,
337
+ [decls .get_egg_sort ( a ) for a in just_arg_types ],
338
+ decls .get_egg_sort ( just_return_type ) ,
336
339
),
337
340
self .default .to_egg (decls ) if self .default else None ,
338
341
self .merge .to_egg (decls ) if self .merge else None ,
@@ -434,7 +437,7 @@ def from_egg(cls, decls: Declarations, call: bindings.Call) -> tuple[JustTypeRef
434
437
435
438
def to_egg (self , decls : Declarations ) -> bindings .Call :
436
439
"""Convert a Call to an egg Call."""
437
- egg_fn = decls ._callable_ref_to_egg_fn [ self .callable ]
440
+ egg_fn = decls .get_egg_fn ( self .callable )
438
441
return bindings .Call (egg_fn , [a .to_egg (decls ) for a in self .args ])
439
442
440
443
def pretty (self , parens = True , ** kwargs ) -> str :
@@ -578,7 +581,7 @@ class SetDecl:
578
581
579
582
def to_egg (self , decls : Declarations ) -> bindings .Set :
580
583
return bindings .Set (
581
- decls ._callable_ref_to_egg_fn [ self .call .callable ] ,
584
+ decls .get_egg_fn ( self .call .callable ) ,
582
585
[a .to_egg (decls ) for a in self .call .args ],
583
586
self .rhs .to_egg (decls ),
584
587
)
@@ -589,9 +592,7 @@ class DeleteDecl:
589
592
call : CallDecl
590
593
591
594
def to_egg (self , decls : Declarations ) -> bindings .Delete :
592
- return bindings .Delete (
593
- decls ._callable_ref_to_egg_fn [self .call .callable ], [a .to_egg (decls ) for a in self .call .args ]
594
- )
595
+ return bindings .Delete (decls .get_egg_fn (self .call .callable ), [a .to_egg (decls ) for a in self .call .args ])
595
596
596
597
597
598
@dataclass (frozen = True )
0 commit comments