Skip to content

Commit b16389a

Browse files
Fix extracting generic classmethods again
1 parent 8afe4a2 commit b16389a

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

python/egglog/declarations.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,16 @@ def get_class_decl(self, name: str) -> ClassDecl:
231231
pass
232232
raise KeyError(f"Class {name} not found")
233233

234+
def get_registered_class_args(self, cls_name: str) -> tuple[JustTypeRef, ...]:
235+
"""
236+
Given a class name, returns the first typevar regsisted with args of that class.
237+
"""
238+
for decl in self.all_decls:
239+
for tp in decl._type_ref_to_egg_sort.keys():
240+
if tp.name == cls_name and tp.args:
241+
return tp.args
242+
return ()
243+
234244
def register_class(self, name: str, n_type_vars: int, egg_sort: Optional[str]) -> Iterable[bindings._Command]:
235245
# Register class first
236246
if name in self._decl._classes:
@@ -538,14 +548,10 @@ def from_egg(cls, mod_decls: ModuleDeclarations, call: bindings.Call) -> TypedEx
538548
for callable_ref in mod_decls.get_callable_refs(call.name):
539549
# If this is a classmethod, we might need the type params that were bound for this type
540550
# egglog currently only allows one instantiated type of any generic sort to be used in any program
541-
# So we just lookup what args were registered for thsi sort
551+
# So we just lookup what args were registered for this sort
542552
if isinstance(callable_ref, ClassMethodRef):
543-
for registered_tp in mod_decls._decl._type_ref_to_egg_sort.keys():
544-
if registered_tp.name == callable_ref.class_name:
545-
tcs = TypeConstraintSolver.from_type_parameters(registered_tp.args)
546-
break
547-
else:
548-
raise ValueError(f"Could not find type parameters for class {callable_ref.class_name}")
553+
cls_args = mod_decls.get_registered_class_args(callable_ref.class_name)
554+
tcs = TypeConstraintSolver.from_type_parameters(cls_args)
549555
else:
550556
tcs = TypeConstraintSolver()
551557
fn_decl = mod_decls.get_function_decl(callable_ref)

0 commit comments

Comments
 (0)