Skip to content

Commit 53ba28f

Browse files
Refactor declarations to keep private things private
1 parent 6c2e36f commit 53ba28f

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

python/egglog/declarations.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ def register_callable_ref(self, ref: CallableRef, egg_name: str) -> None:
185185
def get_egg_fn(self, ref: CallableRef) -> str:
186186
return self._callable_ref_to_egg_fn[ref]
187187

188+
def get_egg_sort(self, ref: JustTypeRef) -> str:
189+
return self._type_ref_to_egg_sort[ref]
190+
188191

189192
def constant_function_decl(type_ref: JustTypeRef) -> FunctionDecl:
190193
"""
@@ -216,10 +219,10 @@ def to_commands(self, decls: Declarations) -> Iterable[bindings._Command]:
216219
"""
217220
Register this type with the egg solver.
218221
"""
219-
egg_name = decls._type_ref_to_egg_sort[self]
222+
egg_name = decls.get_egg_sort(self)
220223
for arg in self.args:
221224
yield from decls._register_sort(arg)
222-
arg_sorts = [cast("bindings._Expr", bindings.Var(decls._type_ref_to_egg_sort[a])) for a in self.args]
225+
arg_sorts = [cast("bindings._Expr", bindings.Var(decls.get_egg_sort(a))) for a in self.args]
223226
yield bindings.Sort(egg_name, (self.name, arg_sorts) if arg_sorts else None)
224227

225228
def to_var(self) -> TypeRefWithVars:
@@ -280,7 +283,7 @@ class ClassMethodRef:
280283
method_name: str
281284

282285
def to_egg(self, decls: Declarations) -> str:
283-
return decls._callable_ref_to_egg_fn[self]
286+
return decls.get_egg_fn(self)
284287

285288
def generate_egg_name(self) -> str:
286289
return f"{self.class_name}.{self.method_name}"
@@ -331,8 +334,8 @@ def to_commands(self, decls: Declarations, egg_name: str) -> Iterable[bindings._
331334
# Remove all vars from the type refs, raising an errory if we find one,
332335
# since we cannot create egg functions with vars
333336
bindings.Schema(
334-
[decls._type_ref_to_egg_sort[a] for a in just_arg_types],
335-
decls._type_ref_to_egg_sort[just_return_type],
337+
[decls.get_egg_sort(a) for a in just_arg_types],
338+
decls.get_egg_sort(just_return_type),
336339
),
337340
self.default.to_egg(decls) if self.default else None,
338341
self.merge.to_egg(decls) if self.merge else None,
@@ -434,7 +437,7 @@ def from_egg(cls, decls: Declarations, call: bindings.Call) -> tuple[JustTypeRef
434437

435438
def to_egg(self, decls: Declarations) -> bindings.Call:
436439
"""Convert a Call to an egg Call."""
437-
egg_fn = decls._callable_ref_to_egg_fn[self.callable]
440+
egg_fn = decls.get_egg_fn(self.callable)
438441
return bindings.Call(egg_fn, [a.to_egg(decls) for a in self.args])
439442

440443
def pretty(self, parens=True, **kwargs) -> str:
@@ -578,7 +581,7 @@ class SetDecl:
578581

579582
def to_egg(self, decls: Declarations) -> bindings.Set:
580583
return bindings.Set(
581-
decls._callable_ref_to_egg_fn[self.call.callable],
584+
decls.get_egg_fn(self.call.callable),
582585
[a.to_egg(decls) for a in self.call.args],
583586
self.rhs.to_egg(decls),
584587
)
@@ -589,9 +592,7 @@ class DeleteDecl:
589592
call: CallDecl
590593

591594
def to_egg(self, decls: Declarations) -> bindings.Delete:
592-
return bindings.Delete(
593-
decls._callable_ref_to_egg_fn[self.call.callable], [a.to_egg(decls) for a in self.call.args]
594-
)
595+
return bindings.Delete(decls.get_egg_fn(self.call.callable), [a.to_egg(decls) for a in self.call.args])
595596

596597

597598
@dataclass(frozen=True)

0 commit comments

Comments
 (0)