@@ -231,6 +231,16 @@ def get_class_decl(self, name: str) -> ClassDecl:
231
231
pass
232
232
raise KeyError (f"Class { name } not found" )
233
233
234
+ def get_registered_class_args (self , cls_name : str ) -> tuple [JustTypeRef , ...]:
235
+ """
236
+ Given a class name, returns the first typevar regsisted with args of that class.
237
+ """
238
+ for decl in self .all_decls :
239
+ for tp in decl ._type_ref_to_egg_sort .keys ():
240
+ if tp .name == cls_name and tp .args :
241
+ return tp .args
242
+ return ()
243
+
234
244
def register_class (self , name : str , n_type_vars : int , egg_sort : Optional [str ]) -> Iterable [bindings ._Command ]:
235
245
# Register class first
236
246
if name in self ._decl ._classes :
@@ -538,14 +548,10 @@ def from_egg(cls, mod_decls: ModuleDeclarations, call: bindings.Call) -> TypedEx
538
548
for callable_ref in mod_decls .get_callable_refs (call .name ):
539
549
# If this is a classmethod, we might need the type params that were bound for this type
540
550
# egglog currently only allows one instantiated type of any generic sort to be used in any program
541
- # So we just lookup what args were registered for thsi sort
551
+ # So we just lookup what args were registered for this sort
542
552
if isinstance (callable_ref , ClassMethodRef ):
543
- for registered_tp in mod_decls ._decl ._type_ref_to_egg_sort .keys ():
544
- if registered_tp .name == callable_ref .class_name :
545
- tcs = TypeConstraintSolver .from_type_parameters (registered_tp .args )
546
- break
547
- else :
548
- raise ValueError (f"Could not find type parameters for class { callable_ref .class_name } " )
553
+ cls_args = mod_decls .get_registered_class_args (callable_ref .class_name )
554
+ tcs = TypeConstraintSolver .from_type_parameters (cls_args )
549
555
else :
550
556
tcs = TypeConstraintSolver ()
551
557
fn_decl = mod_decls .get_function_decl (callable_ref )
0 commit comments