Skip to content

Commit 6cacb85

Browse files
Add simplify to fix docs
1 parent d177fe7 commit 6cacb85

File tree

3 files changed

+63
-112
lines changed

3 files changed

+63
-112
lines changed

docs/how-to-guides.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,18 @@ file_format: mystnb
66

77
## Parsing and running program strings
88

9-
You can provide your program in a special DSL language and run it with {meth}`egg_smol.bindings.EGraph.parse_and_run_program`:
9+
You can provide your program in a special DSL language. You can parse this with {meth}`egg_smol.bindings.EGraph.parse_program` and then run the result with You can parse this with {meth}`egg_smol.bindings.EGraph.run_program`::
1010

1111
```{code-cell}
1212
from egg_smol.bindings import EGraph
1313
1414
egraph = EGraph()
15-
egraph.parse_and_run_program("(check (= (+ 1 2) 3))")
15+
commands = egraph.parse_program("(check (= (+ 1 2) 3))")
16+
commands
17+
```
18+
19+
```{code-cell}
20+
egraph.run_program(*commands)
1621
```
1722

1823
## Developing this package

docs/tutorials/getting-started.ipynb

Lines changed: 27 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@
133133
},
134134
{
135135
"cell_type": "code",
136-
"execution_count": 3,
136+
"execution_count": 8,
137137
"id": "b6424530",
138138
"metadata": {},
139139
"outputs": [
@@ -143,7 +143,7 @@
143143
"(Dim.named(\"x\") * Dim(10)) * Dim(10)"
144144
]
145145
},
146-
"execution_count": 3,
146+
"execution_count": 8,
147147
"metadata": {},
148148
"output_type": "execute_result"
149149
}
@@ -195,6 +195,7 @@
195195
]
196196
},
197197
{
198+
"attachments": {},
198199
"cell_type": "markdown",
199200
"id": "167722d1-60b8-452a-ae54-6a8df4db5b00",
200201
"metadata": {},
@@ -221,33 +222,12 @@
221222
"source": [
222223
"### Testing\n",
223224
"Going back to the notebook, we can test out the that the rewrites are working.\n",
224-
"\n",
225-
"First, we have to add our expression to the egraph. We can do this by defining an empty `let` whcih uses this expression:"
225+
"We can run some number of iterations and extract out the lowest cost expression which is equivalent to our variable:"
226226
]
227227
},
228228
{
229229
"cell_type": "code",
230230
"execution_count": 5,
231-
"id": "29c6d0c1-3249-4597-9d01-ec1fb29dd13f",
232-
"metadata": {
233-
"tags": []
234-
},
235-
"outputs": [],
236-
"source": [
237-
"egraph.register(let(\"\", res))"
238-
]
239-
},
240-
{
241-
"cell_type": "markdown",
242-
"id": "5dc24d6e-8145-4ab3-b114-024dbb53323f",
243-
"metadata": {},
244-
"source": [
245-
"We can then run some number of iterations and extract out the lowest cost expression which is equivalent to our variable:"
246-
]
247-
},
248-
{
249-
"cell_type": "code",
250-
"execution_count": 6,
251231
"id": "31afa12e-da68-4398-91fa-14523f6c099a",
252232
"metadata": {
253233
"tags": []
@@ -259,48 +239,13 @@
259239
"Dim.named(\"x\") * Dim(100)"
260240
]
261241
},
262-
"execution_count": 6,
242+
"execution_count": 5,
263243
"metadata": {},
264244
"output_type": "execute_result"
265245
}
266246
],
267247
"source": [
268-
"egraph.run(10)\n",
269-
"egraph.extract(res)"
270-
]
271-
},
272-
{
273-
"cell_type": "markdown",
274-
"id": "bd08d366-0ebe-4e43-b219-b57fbb13534d",
275-
"metadata": {},
276-
"source": [
277-
"We can also extract a number of variants to see all the equivalent expresions, ordered by their cost:"
278-
]
279-
},
280-
{
281-
"cell_type": "code",
282-
"execution_count": 7,
283-
"id": "17e39076-0257-4bcd-b1c4-04eb5a79791e",
284-
"metadata": {
285-
"tags": []
286-
},
287-
"outputs": [
288-
{
289-
"data": {
290-
"text/plain": [
291-
"[Dim(100) * Dim.named(\"x\"),\n",
292-
" Dim.named(\"x\") * Dim(100),\n",
293-
" Dim(10) * (Dim.named(\"x\") * Dim(10)),\n",
294-
" (Dim.named(\"x\") * Dim(10)) * Dim(10)]"
295-
]
296-
},
297-
"execution_count": 7,
298-
"metadata": {},
299-
"output_type": "execute_result"
300-
}
301-
],
302-
"source": [
303-
"egraph.extract_multiple(res, 10)"
248+
"egraph.simplify(res, 10)"
304249
]
305250
},
306251
{
@@ -316,7 +261,7 @@
316261
},
317262
{
318263
"cell_type": "code",
319-
"execution_count": 8,
264+
"execution_count": 9,
320265
"id": "c5b96cfb",
321266
"metadata": {},
322267
"outputs": [],
@@ -378,7 +323,7 @@
378323
},
379324
{
380325
"cell_type": "code",
381-
"execution_count": 9,
326+
"execution_count": 10,
382327
"id": "cb2b4fb8",
383328
"metadata": {},
384329
"outputs": [],
@@ -409,33 +354,26 @@
409354
},
410355
{
411356
"cell_type": "code",
412-
"execution_count": 10,
357+
"execution_count": 13,
413358
"id": "8d18be2d",
414359
"metadata": {},
415360
"outputs": [
416361
{
417-
"data": {
418-
"text/plain": [
419-
"(Dim.named(\"x\"), Dim.named(\"y\"))"
420-
]
421-
},
422-
"execution_count": 10,
423-
"metadata": {},
424-
"output_type": "execute_result"
362+
"name": "stdout",
363+
"output_type": "stream",
364+
"text": [
365+
"Dim.named(\"y\")\n",
366+
"Dim.named(\"x\")\n"
367+
]
425368
}
426369
],
427370
"source": [
428371
"# If we multiply two identity matrices, we should be able to get the number of columns of the result\n",
429372
"x = Matrix.identity(Dim.named(\"x\"))\n",
430373
"y = Matrix.identity(Dim.named(\"y\"))\n",
431374
"x_mult_y = x @ y\n",
432-
"x_mult_y_ncols = x_mult_y.ncols()\n",
433-
"x_mult_y_nrows = x_mult_y.nrows()\n",
434-
"\n",
435-
"egraph.register(let(\"\", x_mult_y_ncols), let(\"\", x_mult_y_nrows))\n",
436-
"\n",
437-
"egraph.run(10)\n",
438-
"egraph.extract(x_mult_y_nrows), egraph.extract(x_mult_y_ncols)"
375+
"print(egraph.simplify(x_mult_y.ncols(), 10))\n",
376+
"print(egraph.simplify(x_mult_y.nrows(), 10))"
439377
]
440378
},
441379
{
@@ -451,7 +389,7 @@
451389
},
452390
{
453391
"cell_type": "code",
454-
"execution_count": 11,
392+
"execution_count": 14,
455393
"id": "18a91684",
456394
"metadata": {},
457395
"outputs": [],
@@ -488,7 +426,7 @@
488426
},
489427
{
490428
"cell_type": "code",
491-
"execution_count": 12,
429+
"execution_count": 15,
492430
"id": "303ce7f3",
493431
"metadata": {},
494432
"outputs": [],
@@ -522,18 +460,17 @@
522460
},
523461
{
524462
"cell_type": "code",
525-
"execution_count": 13,
463+
"execution_count": 16,
526464
"id": "bb50ade6",
527465
"metadata": {},
528466
"outputs": [
529467
{
530468
"data": {
531469
"text/plain": [
532-
"[kron(Matrix.identity(Dim.named(\"n\")), Matrix.named(\"B\")) @ kron(Matrix.named(\"A\"), Matrix.identity(Dim.named(\"m\"))),\n",
533-
" kron(Matrix.named(\"A\"), Matrix.named(\"B\"))]"
470+
"kron(Matrix.named(\"A\"), Matrix.named(\"B\"))"
534471
]
535472
},
536-
"execution_count": 13,
473+
"execution_count": 16,
537474
"metadata": {},
538475
"output_type": "execute_result"
539476
}
@@ -554,12 +491,7 @@
554491
")\n",
555492
"# Create an example which should equal the kronecker product of A and B\n",
556493
"ex1 = kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m))\n",
557-
"egraph.register(let(\"\", ex1))\n",
558-
"\n",
559-
"egraph.run(20)\n",
560-
"# Verify it matches the expected result\n",
561-
"egraph.check(eq(ex1).to(kron(A, B)))\n",
562-
"egraph.extract_multiple(ex1, 10)"
494+
"egraph.simplify(ex1, 20)"
563495
]
564496
},
565497
{
@@ -573,32 +505,24 @@
573505
},
574506
{
575507
"cell_type": "code",
576-
"execution_count": 14,
508+
"execution_count": 17,
577509
"id": "d8dea199",
578510
"metadata": {},
579511
"outputs": [
580512
{
581513
"data": {
582514
"text/plain": [
583-
"[(kron(Matrix.identity(Dim.named(\"p\")), Matrix.named(\"C\")) @ kron(Matrix.identity(Dim.named(\"n\")), Matrix.identity(Dim.named(\"m\")))) @ kron(\n",
584-
" Matrix.named(\"A\"), Matrix.identity(Dim.named(\"m\"))\n",
585-
" ),\n",
586-
" kron(Matrix.identity(Dim.named(\"p\")), Matrix.named(\"C\")) @ kron(Matrix.named(\"A\"), Matrix.identity(Dim.named(\"m\")))]"
515+
"kron(Matrix.identity(Dim.named(\"p\")), Matrix.named(\"C\")) @ kron(Matrix.named(\"A\"), Matrix.identity(Dim.named(\"m\")))"
587516
]
588517
},
589-
"execution_count": 14,
518+
"execution_count": 17,
590519
"metadata": {},
591520
"output_type": "execute_result"
592521
}
593522
],
594523
"source": [
595524
"ex2 = kron(Matrix.identity(p), C) @ kron(A, Matrix.identity(m))\n",
596-
"egraph.register(let(\"\", ex2))\n",
597-
"\n",
598-
"egraph.run(10)\n",
599-
"# Verify it is not simplified\n",
600-
"egraph.check(ex2 != kron(A, C))\n",
601-
"egraph.extract_multiple(ex2, 10)"
525+
"egraph.simplify(ex2, 20)"
602526
]
603527
},
604528
{

python/egg_smol/egraph.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,22 @@ def _get_egraph(self) -> bindings.EGraph:
116116
raise RuntimeError("Cannot get the e-graph")
117117
return self._egraph
118118

119+
def simplify(self, expr: EXPR, limit: int, *until: Fact) -> EXPR:
120+
"""
121+
Simplifies the given expression.
122+
"""
123+
return self._simplify(expr, limit, None, until)
124+
125+
def _simplify(self, expr: EXPR, limit: int, ruleset: Optional[Ruleset], until: tuple[Fact, ...]) -> EXPR:
126+
tp, decl = expr_parts(expr)
127+
egg_expr = decl.to_egg(self._decls)
128+
self._run_program([bindings.Simplify(egg_expr, Config(limit, ruleset, until)._to_egg_config(self._decls))])
129+
extract_report = self._get_egraph().take_extract_report()
130+
if not extract_report:
131+
raise ValueError("No extract report saved")
132+
new_tp, new_decl = tp_and_expr_decl_from_egg(self._decls, extract_report.expr)
133+
return cast(EXPR, RuntimeExpr(self._decls, new_tp, new_decl))
134+
119135
def relation(self, name: str, *tps: Unpack[TS], egg_fn: Optional[str] = None) -> Callable[[Unpack[TS]], Unit]:
120136
"""
121137
Defines a relation, which is the same as a function which returns unit.
@@ -705,6 +721,11 @@ def run(self, limit: int, *until: Fact) -> bindings.RunReport:
705721
"""
706722
return self._egraph._run_schedule(config(limit, self, *until))
707723

724+
def simplify(self, expr: EXPR, limit: int, *until: Fact) -> EXPR:
725+
"""
726+
Simplify the given expression with this ruleset.
727+
"""
728+
return self._egraph._simplify(expr, limit, self, until)
708729

709730
# We use these builders so that when creating these structures we can type check
710731
# if the arguments are the same type of expression
@@ -1037,13 +1058,14 @@ def __str__(self) -> str:
10371058
args_str = ", ".join(map(str, [self.limit, self.ruleset, *self.until]))
10381059
return f"config({args_str})"
10391060

1040-
def _to_egg(self, declerations: Declarations) -> bindings._Schedule:
1041-
return bindings.Run(
1042-
bindings.RunConfig(
1043-
self.ruleset.name if self.ruleset else "",
1044-
self.limit,
1045-
[fact_decl_to_egg(declerations, _fact_to_decl(fact)) for fact in self.until] if self.until else None,
1046-
)
1061+
def _to_egg(self, decls: Declarations) -> bindings._Schedule:
1062+
return bindings.Run(self._to_egg_config(decls))
1063+
1064+
def _to_egg_config(self, decls: Declarations) -> bindings.RunConfig:
1065+
return bindings.RunConfig(
1066+
self.ruleset.name if self.ruleset else "",
1067+
self.limit,
1068+
[fact_decl_to_egg(decls, _fact_to_decl(fact)) for fact in self.until] if self.until else None,
10471069
)
10481070

10491071

0 commit comments

Comments
 (0)