Skip to content

Commit 788ea64

Browse files
Merge pull request #12 from metadsl/increase-line-length
Increase max line length to 120
2 parents 82e7a42 + 895c2a8 commit 788ea64

File tree

10 files changed

+68
-214
lines changed

10 files changed

+68
-214
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ docs = [
4242
]
4343

4444

45+
[tool.black]
46+
line-length = 120
47+
4548
[tool.isort]
4649
profile = "black"
4750
skip_gitignore = true

python/egg_smol/bindings.pyi

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,13 @@ def parse(input: str) -> list[_Command]: ...
77
class EGraph:
88
def parse_and_run_program(self, input: str) -> list[str]: ...
99
def declare_constructor(self, variant: Variant, sort: str) -> None: ...
10-
def declare_sort(
11-
self, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None
12-
) -> None: ...
10+
def declare_sort(self, name: str, presort_and_args: tuple[str, list[_Expr]] | None = None) -> None: ...
1311
def declare_function(self, decl: FunctionDecl) -> None: ...
1412
def define(self, name: str, expr: _Expr, cost: int | None = None) -> None: ...
1513
def add_rewrite(self, rewrite: Rewrite) -> str: ...
1614
def run_rules(self, limit: int) -> tuple[timedelta, timedelta, timedelta]: ...
1715
def check_fact(self, fact: _Fact) -> None: ...
18-
def extract_expr(
19-
self, expr: _Expr, variants: int = 0
20-
) -> tuple[int, _Expr, list[_Expr]]: ...
16+
def extract_expr(self, expr: _Expr, variants: int = 0) -> tuple[int, _Expr, list[_Expr]]: ...
2117
def add_rule(self, rule: Rule) -> str: ...
2218
def eval_actions(self, *actions: _Action) -> None: ...
2319
def push(self) -> None: ...
@@ -133,9 +129,7 @@ class FunctionDecl:
133129

134130
@final
135131
class Variant:
136-
def __init__(
137-
self, name: str, types: list[str], cost: int | None = None
138-
) -> None: ...
132+
def __init__(self, name: str, types: list[str], cost: int | None = None) -> None: ...
139133
name: str
140134
types: list[str]
141135
cost: int | None
@@ -158,9 +152,7 @@ class Rewrite:
158152
rhs: _Expr
159153
conditions: list[_Fact]
160154

161-
def __init__(
162-
self, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = []
163-
) -> None: ...
155+
def __init__(self, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = []) -> None: ...
164156

165157
@final
166158
class Datatype:

python/egg_smol/declarations.py

Lines changed: 16 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,7 @@ class Declarations:
8787
# Bidirectional mapping between egg function names and python callable references.
8888
# Note that there are possibly mutliple callable references for a single egg function name, like `+`
8989
# for both int and rational classes.
90-
egg_fn_to_callable_refs: defaultdict[str, set[CallableRef]] = field(
91-
default_factory=lambda: defaultdict(set)
92-
)
90+
egg_fn_to_callable_refs: defaultdict[str, set[CallableRef]] = field(default_factory=lambda: defaultdict(set))
9391
callable_ref_to_egg_fn: dict[CallableRef, str] = field(default_factory=dict)
9492

9593
# Bidirectional mapping between egg sort names and python type references.
@@ -119,9 +117,7 @@ def get_function_decl(self, ref: CallableRef) -> FunctionDecl:
119117
return self.constants[ref.name].to_function_decl()
120118
assert_never(ref)
121119

122-
def register_sort(
123-
self, type_ref: JustTypeRef, egg_name: Optional[str] = None
124-
) -> str:
120+
def register_sort(self, type_ref: JustTypeRef, egg_name: Optional[str] = None) -> str:
125121
egg_name = egg_name or type_ref.generate_egg_name()
126122
if egg_name in self.egg_sort_to_type_ref:
127123
raise ValueError(f"Sort {egg_name} is already registered.")
@@ -154,10 +150,7 @@ def to_egg(self, decls: Declarations, egraph: bindings.EGraph) -> str:
154150
raise ValueError(f"Type {self.name} is not registered.")
155151
# If this is a type with arguments and it is not registered, then we need to register it
156152
egg_name = decls.register_sort(self)
157-
arg_sorts = [
158-
cast(bindings._Expr, bindings.Var(a.to_egg(decls, egraph)))
159-
for a in self.args
160-
]
153+
arg_sorts = [cast(bindings._Expr, bindings.Var(a.to_egg(decls, egraph))) for a in self.args]
161154
egraph.declare_sort(egg_name, (self.name, arg_sorts))
162155
return egg_name
163156

@@ -265,9 +258,7 @@ class FunctionDecl:
265258
default: Optional[ExprDecl] = None
266259
merge: Optional[ExprDecl] = None
267260

268-
def to_egg(
269-
self, decls: Declarations, egraph: bindings.EGraph, ref: CallableRef
270-
) -> bindings.FunctionDecl:
261+
def to_egg(self, decls: Declarations, egraph: bindings.EGraph, ref: CallableRef) -> bindings.FunctionDecl:
271262
return bindings.FunctionDecl(
272263
decls.callable_ref_to_egg_fn[ref],
273264
# Remove all vars from the type refs, raising an errory if we find one,
@@ -288,9 +279,7 @@ class VarDecl:
288279

289280
@classmethod
290281
def from_egg(cls, var: bindings.Var) -> tuple[JustTypeRef, LitDecl]:
291-
raise NotImplementedError(
292-
"Cannot turn var into egg type because typing unknown."
293-
)
282+
raise NotImplementedError("Cannot turn var into egg type because typing unknown.")
294283

295284
def to_egg(self, _decls: Declarations) -> bindings.Var:
296285
return bindings.Var(self.name)
@@ -349,14 +338,10 @@ class CallDecl:
349338

350339
def __post_init__(self):
351340
if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef):
352-
raise ValueError(
353-
"Cannot bind type parameters to a non-class method callable."
354-
)
341+
raise ValueError("Cannot bind type parameters to a non-class method callable.")
355342

356343
@classmethod
357-
def from_egg(
358-
cls, decls: Declarations, call: bindings.Call
359-
) -> tuple[JustTypeRef, CallDecl]:
344+
def from_egg(cls, decls: Declarations, call: bindings.Call) -> tuple[JustTypeRef, CallDecl]:
360345
from .type_constraint_solver import TypeConstraintSolver
361346

362347
results = [tp_and_expr_decl_from_egg(decls, a) for a in call.args]
@@ -367,9 +352,7 @@ def from_egg(
367352
for callable_ref in decls.egg_fn_to_callable_refs[call.name]:
368353
tcs = TypeConstraintSolver()
369354
fn_decl = decls.get_function_decl(callable_ref)
370-
return_tp = tcs.infer_return_type(
371-
fn_decl.arg_types, fn_decl.return_type, arg_types
372-
)
355+
return_tp = tcs.infer_return_type(fn_decl.arg_types, fn_decl.return_type, arg_types)
373356
return return_tp, cls(callable_ref, arg_decls)
374357
raise ValueError(f"Could not find callable ref for call {call}")
375358

@@ -421,34 +404,12 @@ def test_expr_pretty():
421404
assert LitDecl("foo").pretty() == 'String("foo")'
422405
assert LitDecl(None).pretty() == "unit()"
423406
assert CallDecl(FunctionRef("foo"), (VarDecl("x"),)).pretty() == "foo(x)"
424-
assert (
425-
CallDecl(
426-
FunctionRef("foo"), (VarDecl("x"), VarDecl("y"), VarDecl("z"))
427-
).pretty()
428-
== "foo(x, y, z)"
429-
)
430-
assert (
431-
CallDecl(MethodRef("foo", "__add__"), (VarDecl("x"), VarDecl("y"))).pretty()
432-
== "x + y"
433-
)
434-
assert (
435-
CallDecl(MethodRef("foo", "__getitem__"), (VarDecl("x"), VarDecl("y"))).pretty()
436-
== "x[y]"
437-
)
438-
assert (
439-
CallDecl(
440-
ClassMethodRef("foo", "__init__"), (VarDecl("x"), VarDecl("y"))
441-
).pretty()
442-
== "foo(x, y)"
443-
)
444-
assert (
445-
CallDecl(ClassMethodRef("foo", "bar"), (VarDecl("x"), VarDecl("y"))).pretty()
446-
== "foo.bar(x, y)"
447-
)
448-
assert (
449-
CallDecl(MethodRef("foo", "__call__"), (VarDecl("x"), VarDecl("y"))).pretty()
450-
== "x(y)"
451-
)
407+
assert CallDecl(FunctionRef("foo"), (VarDecl("x"), VarDecl("y"), VarDecl("z"))).pretty() == "foo(x, y, z)"
408+
assert CallDecl(MethodRef("foo", "__add__"), (VarDecl("x"), VarDecl("y"))).pretty() == "x + y"
409+
assert CallDecl(MethodRef("foo", "__getitem__"), (VarDecl("x"), VarDecl("y"))).pretty() == "x[y]"
410+
assert CallDecl(ClassMethodRef("foo", "__init__"), (VarDecl("x"), VarDecl("y"))).pretty() == "foo(x, y)"
411+
assert CallDecl(ClassMethodRef("foo", "bar"), (VarDecl("x"), VarDecl("y"))).pretty() == "foo.bar(x, y)"
412+
assert CallDecl(MethodRef("foo", "__call__"), (VarDecl("x"), VarDecl("y"))).pretty() == "x(y)"
452413
assert (
453414
CallDecl(
454415
ClassMethodRef("Map", "__init__"),
@@ -462,9 +423,7 @@ def test_expr_pretty():
462423
ExprDecl = Union[VarDecl, LitDecl, CallDecl]
463424

464425

465-
def tp_and_expr_decl_from_egg(
466-
decls: Declarations, expr: bindings._Expr
467-
) -> tuple[JustTypeRef, ExprDecl]:
426+
def tp_and_expr_decl_from_egg(decls: Declarations, expr: bindings._Expr) -> tuple[JustTypeRef, ExprDecl]:
468427
if isinstance(expr, bindings.Var):
469428
return VarDecl.from_egg(expr)
470429
if isinstance(expr, bindings.Lit):
@@ -557,9 +516,7 @@ class DeleteDecl:
557516
call: CallDecl
558517

559518
def to_egg(self, decls: Declarations) -> bindings.Delete:
560-
return bindings.Delete(
561-
self.call.callable.to_egg(decls), [a.to_egg(decls) for a in self.call.args]
562-
)
519+
return bindings.Delete(self.call.callable.to_egg(decls), [a.to_egg(decls) for a in self.call.args])
563520

564521

565522
@dataclass(frozen=True)

python/egg_smol/egraph.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,9 @@ def extract_multiple(self, expr: EXPR, n: int) -> list[EXPR]:
5959
"""
6060
tp, decl = expr_parts(expr)
6161
egg_expr = decl.to_egg(self._decls)
62-
_cost, new_egg_expr, variants = self._egraph.extract_expr(
63-
egg_expr, variants=n + 1
64-
)
65-
new_decls = [
66-
tp_and_expr_decl_from_egg(self._decls, egg_expr)[1]
67-
for egg_expr in variants[::-1]
68-
]
69-
return [
70-
cast(EXPR, RuntimeExpr(self._decls, tp, new_decl)) for new_decl in new_decls
71-
]
62+
_cost, new_egg_expr, variants = self._egraph.extract_expr(egg_expr, variants=n + 1)
63+
new_decls = [tp_and_expr_decl_from_egg(self._decls, egg_expr)[1] for egg_expr in variants[::-1]]
64+
return [cast(EXPR, RuntimeExpr(self._decls, tp, new_decl)) for new_decl in new_decls]
7265

7366
def define(self, name: str, expr: EXPR, cost: Optional[int] = None) -> EXPR:
7467
"""

python/egg_smol/monkeypatch.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ def _evaluate_monkeypatch(self, globalns, localns):
2828
"Forward references must evaluate to types.",
2929
is_argument=self.__forward_is_argument__,
3030
)
31-
self.__forward_value__ = typing._eval_type( # type: ignore
32-
type_, globalns, localns
33-
)
31+
self.__forward_value__ = typing._eval_type(type_, globalns, localns) # type: ignore
3432
self.__forward_evaluated__ = True
3533
return self.__forward_value__

python/egg_smol/registry.py

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,7 @@ def _class(
132132
"""
133133
cls_name = cls.__name__
134134
# Get all the methods from the class
135-
cls_dict: dict[str, Any] = {
136-
k: v for k, v in cls.__dict__.items() if k not in IGNORED_ATTRIBUTES
137-
}
135+
cls_dict: dict[str, Any] = {k: v for k, v in cls.__dict__.items() if k not in IGNORED_ATTRIBUTES}
138136
parameters: list[TypeVar] = cls_dict.pop("__parameters__", [])
139137

140138
# Register class first
@@ -147,9 +145,7 @@ def _class(
147145
self._on_register_sort(egg_sort or cls_name)
148146

149147
# The type ref of self is paramterized by the type vars
150-
slf_type_ref = TypeRefWithVars(
151-
cls_name, tuple(ClassTypeVarRef(i) for i in range(n_type_vars))
152-
)
148+
slf_type_ref = TypeRefWithVars(cls_name, tuple(ClassTypeVarRef(i) for i in range(n_type_vars)))
153149

154150
# Then register each of its methods
155151
for method_name, method in cls_dict.items():
@@ -321,9 +317,7 @@ def _generate_function_decl(
321317
cls_type_and_name: Optional[tuple[type | RuntimeClass, str]] = None,
322318
) -> FunctionDecl:
323319
if not isinstance(fn, FunctionType):
324-
raise NotImplementedError(
325-
f"Can only generate function decls for functions not {fn} {type(fn)}"
326-
)
320+
raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
327321

328322
hint_globals = fn.__globals__.copy()
329323

@@ -336,31 +330,20 @@ def _generate_function_decl(
336330
raise ValueError("Init function must have a self type")
337331
return_type = first_arg
338332
else:
339-
return_type = self._resolve_type_annotation(
340-
hints["return"], cls_typevars, cls_type_and_name
341-
)
333+
return_type = self._resolve_type_annotation(hints["return"], cls_typevars, cls_type_and_name)
342334

343335
params = list(signature(fn).parameters.values())
344336
# Remove first arg if this is a classmethod or a method, since it won't have an annotation
345337
if first_arg is not None:
346338
first, *params = params
347339
if first.annotation != Parameter.empty:
348-
raise ValueError(
349-
f"First arg of a method must not have an annotation, not {first.annotation}"
350-
)
340+
raise ValueError(f"First arg of a method must not have an annotation, not {first.annotation}")
351341

352342
for param in params:
353343
if param.kind != Parameter.POSITIONAL_OR_KEYWORD:
354-
raise ValueError(
355-
f"Can only register functions with positional or keyword args, not {param.kind}"
356-
)
344+
raise ValueError(f"Can only register functions with positional or keyword args, not {param.kind}")
357345

358-
arg_types = tuple(
359-
self._resolve_type_annotation(
360-
hints[t.name], cls_typevars, cls_type_and_name
361-
)
362-
for t in params
363-
)
346+
arg_types = tuple(self._resolve_type_annotation(hints[t.name], cls_typevars, cls_type_and_name) for t in params)
364347
# If the first arg is a self, and this not an __init__ fn, add this as a typeref
365348
if isinstance(first_arg, (ClassTypeVarRef, TypeRefWithVars)) and not is_init:
366349
arg_types = (first_arg,) + arg_types
@@ -405,13 +388,9 @@ def _resolve_type_annotation(
405388
raise TypeError("Union types are only supported for type promotion")
406389
fst, snd = args
407390
if fst in {int, str}:
408-
return self._resolve_type_annotation(
409-
snd, cls_typevars, cls_type_and_name
410-
)
391+
return self._resolve_type_annotation(snd, cls_typevars, cls_type_and_name)
411392
if snd in {int, str}:
412-
return self._resolve_type_annotation(
413-
fst, cls_typevars, cls_type_and_name
414-
)
393+
return self._resolve_type_annotation(fst, cls_typevars, cls_type_and_name)
415394
raise TypeError("Union types are only supported for type promotion")
416395

417396
# If this is the type for the class, use the class name
@@ -471,9 +450,7 @@ class _WrappedMethod(Generic[P, EXPR]):
471450
fn: Callable[P, EXPR]
472451

473452
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> EXPR:
474-
raise NotImplementedError(
475-
"We should never call a wrapped method. Did you forget to wrap the class?"
476-
)
453+
raise NotImplementedError("We should never call a wrapped method. Did you forget to wrap the class?")
477454

478455

479456
class BaseExpr:
@@ -724,9 +701,7 @@ def __str__(self) -> str:
724701
def _to_decl(self) -> SetDecl:
725702
lhs = expr_parts(self.lhs)[1]
726703
if not isinstance(lhs, CallDecl):
727-
raise ValueError(
728-
f"Can only create a call with a call for the lhs, got {lhs}"
729-
)
704+
raise ValueError(f"Can only create a call with a call for the lhs, got {lhs}")
730705
return SetDecl(lhs, expr_parts(self.rhs)[1])
731706

732707

0 commit comments

Comments
 (0)