Skip to content

Commit 83c61e7

Browse files
Merge pull request #23 from metadsl/upgrade
Upgrade egglog dependency
2 parents 68895fc + 645c1b3 commit 83c61e7

17 files changed

+292
-163
lines changed

Cargo.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ crate-type = ["cdylib"]
1111
[dependencies]
1212
# Use unreleased version to depend on signature improvements https://github.com/PyO3/pyo3/pull/2702
1313
pyo3 = { version = "0.18.1", features = ["extension-module"] }
14-
egg-smol = { git = "https://github.com/egraphs-good/egglog", rev = "30feaaab88452ec4b6c5f7a199345298bac2dd0f" }
14+
egglog = { git = "https://github.com/egraphs-good/egglog", rev = "39b199d9bfce9cc47d0c54977279c5b04231e717" }
1515
# egg-smol = { path = "../egglog" }
1616
pyo3-log = "0.8.1"
1717
log = "0.4.17"

docs/changelog.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
- Renamed `config()` to `run()` to better match `egglog` command
66
- Fixed `relation` type signature
77
- Added default limit of 1 to `run()` to match `egglog` command and moved to second arg
8+
- Upgraded `egglog` dependency ([changes](https://github.com/egraphs-good/egglog/compare/30feaaab88452ec4b6c5f7a199345298bac2dd0f...39b199d9bfce9cc47d0c54977279c5b04231e717))
9+
- Added `Set` sort and removed set method from `Map`
10+
- Added `Vec` sort
11+
- Added support for variable args for builtin functions, to use in creation of `Vec` and `Set` sorts.
812

913
## 0.4.0 (2023-05-03)
1014

python/egglog/bindings.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@ class Set:
9595
args: list[_Expr]
9696
rhs: _Expr
9797

98+
@final
99+
class SetNoTrack:
100+
def __init__(self, lhs: str, args: list[_Expr], rhs: _Expr) -> None: ...
101+
lhs: str
102+
args: list[_Expr]
103+
rhs: _Expr
104+
98105
@final
99106
class Delete:
100107
sym: str
@@ -117,7 +124,7 @@ class Expr_:
117124
def __init__(self, expr: _Expr) -> None: ...
118125
expr: _Expr
119126

120-
_Action = Let | Set | Delete | Union | Panic | Expr_
127+
_Action = Let | Set | SetNoTrack | Delete | Union | Panic | Expr_
121128

122129
##
123130
# Other Structs

python/egglog/builtins.py

Lines changed: 78 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
"StringLike",
2020
"Map",
2121
"Rational",
22+
"Set",
23+
"Vec",
2224
]
2325

2426

@@ -159,41 +161,69 @@ def __init__(self, value: str):
159161

160162
@BUILTINS.class_(egg_sort="Map")
161163
class Map(BaseExpr, Generic[T, V]):
162-
@BUILTINS.method(egg_fn="empty")
164+
@BUILTINS.method(egg_fn="map-empty")
163165
@classmethod
164166
def empty(cls) -> Map[T, V]: # type: ignore[empty-body]
165167
...
166168

167-
@BUILTINS.method(egg_fn="insert")
169+
@BUILTINS.method(egg_fn="map-insert")
168170
def insert(self, key: T, value: V) -> Map[T, V]: # type: ignore[empty-body]
169171
...
170172

171-
@BUILTINS.method(egg_fn="get")
173+
@BUILTINS.method(egg_fn="map-get")
172174
def __getitem__(self, key: T) -> V: # type: ignore[empty-body]
173175
...
174176

175-
@BUILTINS.method(egg_fn="not-contains")
177+
@BUILTINS.method(egg_fn="map-not-contains")
176178
def not_contains(self, key: T) -> Unit: # type: ignore[empty-body]
177179
...
178180

179-
@BUILTINS.method(egg_fn="contains")
181+
@BUILTINS.method(egg_fn="map-contains")
180182
def contains(self, key: T) -> Unit: # type: ignore[empty-body]
181183
...
182184

185+
@BUILTINS.method(egg_fn="map-remove")
186+
def remove(self, key: T) -> Map[T, V]: # type: ignore[empty-body]
187+
...
188+
189+
190+
@BUILTINS.class_(egg_sort="Set")
191+
class Set(BaseExpr, Generic[T]):
192+
@BUILTINS.method(egg_fn="set-of")
193+
def __init__(self, *args: T) -> None:
194+
...
195+
196+
@BUILTINS.method(egg_fn="set-empty")
197+
@classmethod
198+
def empty(cls) -> Set[T]: # type: ignore[empty-body]
199+
...
200+
201+
@BUILTINS.method(egg_fn="set-insert")
202+
def insert(self, value: T) -> Set[T]: # type: ignore[empty-body]
203+
...
204+
205+
@BUILTINS.method(egg_fn="set-not-contains")
206+
def not_contains(self, value: T) -> Unit: # type: ignore[empty-body]
207+
...
208+
209+
@BUILTINS.method(egg_fn="set-contains")
210+
def contains(self, value: T) -> Unit: # type: ignore[empty-body]
211+
...
212+
213+
@BUILTINS.method(egg_fn="set-remove")
214+
def remove(self, value: T) -> Set[T]: # type: ignore[empty-body]
215+
...
216+
183217
@BUILTINS.method(egg_fn="set-union")
184-
def __or__(self, __t: Map[T, V]) -> Map[T, V]: # type: ignore[empty-body]
218+
def __or__(self, other: Set[T]) -> Set[T]: # type: ignore[empty-body]
185219
...
186220

187221
@BUILTINS.method(egg_fn="set-diff")
188-
def __sub__(self, __t: Map[T, V]) -> Map[T, V]: # type: ignore[empty-body]
222+
def __sub__(self, other: Set[T]) -> Set[T]: # type: ignore[empty-body]
189223
...
190224

191225
@BUILTINS.method(egg_fn="set-intersect")
192-
def __and__(self, __t: Map[T, V]) -> Map[T, V]: # type: ignore[empty-body]
193-
...
194-
195-
@BUILTINS.method(egg_fn="map-remove")
196-
def remove(self, key: T) -> Map[T, V]: # type: ignore[empty-body]
226+
def __and__(self, other: Set[T]) -> Set[T]: # type: ignore[empty-body]
197227
...
198228

199229

@@ -203,6 +233,10 @@ class Rational(BaseExpr):
203233
def __init__(self, num: i64Like, den: i64Like):
204234
...
205235

236+
@BUILTINS.method(egg_fn="to-f64")
237+
def to_f64(self) -> f64: # type: ignore[empty-body]
238+
...
239+
206240
@BUILTINS.method(egg_fn="+")
207241
def __add__(self, other: Rational) -> Rational: # type: ignore[empty-body]
208242
...
@@ -262,3 +296,35 @@ def sqrt(self) -> Rational: # type: ignore[empty-body]
262296
@BUILTINS.method(egg_fn="cbrt")
263297
def cbrt(self) -> Rational: # type: ignore[empty-body]
264298
...
299+
300+
301+
@BUILTINS.class_(egg_sort="Vec")
302+
class Vec(BaseExpr, Generic[T]):
303+
@BUILTINS.method(egg_fn="vec-of")
304+
def __init__(self, *args: T) -> None:
305+
...
306+
307+
@BUILTINS.method(egg_fn="vec-empty")
308+
@classmethod
309+
def empty(cls) -> Vec[T]: # type: ignore[empty-body]
310+
...
311+
312+
@BUILTINS.method(egg_fn="vec-append")
313+
def append(self, *others: Vec[T]) -> Vec[T]: # type: ignore[empty-body]
314+
...
315+
316+
@BUILTINS.method(egg_fn="vec-push")
317+
def push(self, value: T) -> Vec[T]: # type: ignore[empty-body]
318+
...
319+
320+
@BUILTINS.method(egg_fn="vec-pop")
321+
def pop(self) -> Vec[T]: # type: ignore[empty-body]
322+
...
323+
324+
@BUILTINS.method(egg_fn="vec-not-contains")
325+
def not_contains(self, value: T) -> Unit: # type: ignore[empty-body]
326+
...
327+
328+
@BUILTINS.method(egg_fn="vec-contains")
329+
def contains(self, value: T) -> Unit: # type: ignore[empty-body]
330+
...

python/egglog/declarations.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,12 +311,15 @@ class FunctionDecl:
311311
# TODO: Add arg name to arg so can call with keyword arg
312312
arg_types: tuple[TypeOrVarRef, ...]
313313
return_type: TypeOrVarRef
314+
var_arg_type: Optional[TypeOrVarRef] = None
314315
cost: Optional[int] = None
315316
default: Optional[ExprDecl] = None
316317
merge: Optional[ExprDecl] = None
317318
merge_action: tuple[ActionDecl, ...] = ()
318319

319320
def to_commands(self, decls: Declarations, egg_name: str) -> Iterable[bindings._Command]:
321+
if self.var_arg_type is not None:
322+
raise NotImplementedError("egglog does not support variable arguments yet.")
320323
just_arg_types = [a.to_just() for a in self.arg_types]
321324
for a in just_arg_types:
322325
yield from decls._register_sort(a)
@@ -406,6 +409,7 @@ class CallDecl:
406409
callable: CallableRef
407410
args: tuple[ExprDecl, ...] = ()
408411
# type parameters that were bound to the callable, if it is a classmethod
412+
# Used for pretty printing classmethod calls with type parameters
409413
bound_tp_params: Optional[tuple[JustTypeRef, ...]] = None
410414

411415
def __post_init__(self):
@@ -424,7 +428,7 @@ def from_egg(cls, decls: Declarations, call: bindings.Call) -> tuple[JustTypeRef
424428
for callable_ref in decls._egg_fn_to_callable_refs[call.name]:
425429
tcs = TypeConstraintSolver()
426430
fn_decl = decls._get_function_decl(callable_ref)
427-
return_tp = tcs.infer_return_type(fn_decl.arg_types, fn_decl.return_type, arg_types)
431+
return_tp = tcs.infer_return_type(fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, arg_types)
428432
return return_tp, cls(callable_ref, arg_decls)
429433
raise ValueError(f"Could not find callable ref for call {call}")
430434

python/egglog/egraph.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,10 +561,22 @@ def _generate_function_decl(
561561
if first.annotation != Parameter.empty:
562562
raise ValueError(f"First arg of a method must not have an annotation, not {first.annotation}")
563563

564+
# Check that all the params are positional or keyword, and that there is only one var arg at the end
565+
found_var_arg = False
564566
for param in params:
565-
if param.kind != Parameter.POSITIONAL_OR_KEYWORD:
567+
if found_var_arg:
568+
raise ValueError("Can only have a single var arg at the end")
569+
kind = param.kind
570+
if kind == Parameter.VAR_POSITIONAL:
571+
found_var_arg = True
572+
elif kind != Parameter.POSITIONAL_OR_KEYWORD:
566573
raise ValueError(f"Can only register functions with positional or keyword args, not {param.kind}")
567574

575+
if found_var_arg:
576+
var_arg_param, *params = params
577+
var_arg_type = self._resolve_type_annotation(hints[var_arg_param.name], cls_typevars, cls_type_and_name)
578+
else:
579+
var_arg_type = None
568580
arg_types = tuple(self._resolve_type_annotation(hints[t.name], cls_typevars, cls_type_and_name) for t in params)
569581
# If the first arg is a self, and this not an __init__ fn, add this as a typeref
570582
if isinstance(first_arg, (ClassTypeVarRef, TypeRefWithVars)) and not is_init:
@@ -594,6 +606,7 @@ def _generate_function_decl(
594606
)
595607
decl = FunctionDecl(
596608
return_type=return_type,
609+
var_arg_type=var_arg_type,
597610
arg_types=arg_types,
598611
cost=cost,
599612
default=default_decl,

python/egglog/examples/lambda.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def if_(c: Term, t: Term, f: Term) -> Term:
7979
...
8080

8181

82-
StringSet = Map[Var, i64]
82+
StringSet = Set[Var]
8383

8484

8585
@egraph.function(merge=lambda old, new: old & new)
@@ -95,7 +95,7 @@ def freer(t: Term) -> StringSet:
9595
egraph.register(
9696
# freer
9797
rule(eq(t).to(Term.val(v))).then(set_(freer(t)).to(StringSet.empty())),
98-
rule(eq(t).to(Term.var(x))).then(set_(freer(t)).to(StringSet.empty().insert(x, i64(1)))),
98+
rule(eq(t).to(Term.var(x))).then(set_(freer(t)).to(StringSet.empty().insert(x))),
9999
rule(eq(t).to(t1 + t2), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
100100
rule(eq(t).to(t1 == t2), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),
101101
rule(eq(t).to(t1(t2)), eq(freer(t1)).to(fv1), eq(freer(t2)).to(fv2)).then(set_(freer(t)).to(fv1 | fv2)),

python/egglog/runtime.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ def __dir__(self) -> list[str]:
6767
possible_methods.append("__call__")
6868
return possible_methods
6969

70-
def __getitem__(self, args: tuple[RuntimeTypeArgType, ...]) -> RuntimeParamaterizedClass:
70+
def __getitem__(self, args: tuple[RuntimeTypeArgType, ...] | RuntimeTypeArgType) -> RuntimeParamaterizedClass:
71+
if not isinstance(args, tuple):
72+
args = (args,)
7173
tp = JustTypeRef(self.__egg_name__, tuple(class_to_ref(arg) for arg in args))
7274
return RuntimeParamaterizedClass(self.__egg_decls__, tp)
7375

@@ -160,7 +162,7 @@ def _call(
160162
tcs = TypeConstraintSolver()
161163

162164
if fn_decl is not None:
163-
return_tp = tcs.infer_return_type(fn_decl.arg_types, fn_decl.return_type, arg_types)
165+
return_tp = tcs.infer_return_type(fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, arg_types)
164166
else:
165167
return_tp = JustTypeRef("unit")
166168

python/egglog/type_constraint_solver.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass, field
4-
from typing import Collection
4+
from itertools import chain, repeat
5+
from typing import Collection, Optional
56

67
from .declarations import *
78

@@ -38,24 +39,28 @@ def infer_return_type(
3839
self,
3940
fn_args: Collection[TypeOrVarRef],
4041
fn_return: TypeOrVarRef,
42+
fn_var_args: Optional[TypeOrVarRef],
4143
args: Collection[JustTypeRef],
4244
) -> JustTypeRef:
4345
# Infer the type of each type variable based on the actual types of the arguments
44-
self._infer_typevars_zip(fn_args, args)
46+
self._infer_typevars_zip(fn_args, fn_var_args, args)
4547
# Substitute the type variables with their inferred types
4648
return self._subtitute_typevars(fn_return)
4749

48-
def _infer_typevars_zip(self, fn_args: Collection[TypeOrVarRef], args: Collection[JustTypeRef]) -> None:
49-
if len(fn_args) != len(args):
50+
def _infer_typevars_zip(
51+
self, fn_args: Collection[TypeOrVarRef], fn_var_args: Optional[TypeOrVarRef], args: Collection[JustTypeRef]
52+
) -> None:
53+
if len(fn_args) != len(args) if fn_var_args is None else len(fn_args) > len(args):
5054
raise TypeConstraintError(f"Expected {len(fn_args)} args, got {len(args)}")
51-
for fn_arg, arg in zip(fn_args, args):
55+
all_fn_args = fn_args if fn_var_args is None else chain(fn_args, repeat(fn_var_args))
56+
for fn_arg, arg in zip(all_fn_args, args):
5257
self._infer_typevars(fn_arg, arg)
5358

5459
def _infer_typevars(self, fn_arg: TypeOrVarRef, arg: JustTypeRef) -> None:
5560
if isinstance(fn_arg, TypeRefWithVars):
5661
if fn_arg.name != arg.name:
5762
raise TypeConstraintError(f"Expected {fn_arg.name}, got {arg.name}")
58-
self._infer_typevars_zip(fn_arg.args, arg.args)
63+
self._infer_typevars_zip(fn_arg.args, None, arg.args)
5964
elif fn_arg.index not in self._cls_typevar_index_to_type:
6065
self._cls_typevar_index_to_type[fn_arg.index] = arg
6166
elif self._cls_typevar_index_to_type[fn_arg.index] != arg:

0 commit comments

Comments
 (0)