|
133 | 133 | },
|
134 | 134 | {
|
135 | 135 | "cell_type": "code",
|
136 |
| - "execution_count": 3, |
| 136 | + "execution_count": 8, |
137 | 137 | "id": "b6424530",
|
138 | 138 | "metadata": {},
|
139 | 139 | "outputs": [
|
|
143 | 143 | "(Dim.named(\"x\") * Dim(10)) * Dim(10)"
|
144 | 144 | ]
|
145 | 145 | },
|
146 |
| - "execution_count": 3, |
| 146 | + "execution_count": 8, |
147 | 147 | "metadata": {},
|
148 | 148 | "output_type": "execute_result"
|
149 | 149 | }
|
|
195 | 195 | ]
|
196 | 196 | },
|
197 | 197 | {
|
| 198 | + "attachments": {}, |
198 | 199 | "cell_type": "markdown",
|
199 | 200 | "id": "167722d1-60b8-452a-ae54-6a8df4db5b00",
|
200 | 201 | "metadata": {},
|
|
221 | 222 | "source": [
|
222 | 223 | "### Testing\n",
|
223 | 224 | "Going back to the notebook, we can test out the that the rewrites are working.\n",
|
224 |
| - "\n", |
225 |
| - "First, we have to add our expression to the egraph. We can do this by defining an empty `let` whcih uses this expression:" |
| 225 | + "We can run some number of iterations and extract out the lowest cost expression which is equivalent to our variable:" |
226 | 226 | ]
|
227 | 227 | },
|
228 | 228 | {
|
229 | 229 | "cell_type": "code",
|
230 | 230 | "execution_count": 5,
|
231 |
| - "id": "29c6d0c1-3249-4597-9d01-ec1fb29dd13f", |
232 |
| - "metadata": { |
233 |
| - "tags": [] |
234 |
| - }, |
235 |
| - "outputs": [], |
236 |
| - "source": [ |
237 |
| - "egraph.register(let(\"\", res))" |
238 |
| - ] |
239 |
| - }, |
240 |
| - { |
241 |
| - "cell_type": "markdown", |
242 |
| - "id": "5dc24d6e-8145-4ab3-b114-024dbb53323f", |
243 |
| - "metadata": {}, |
244 |
| - "source": [ |
245 |
| - "We can then run some number of iterations and extract out the lowest cost expression which is equivalent to our variable:" |
246 |
| - ] |
247 |
| - }, |
248 |
| - { |
249 |
| - "cell_type": "code", |
250 |
| - "execution_count": 6, |
251 | 231 | "id": "31afa12e-da68-4398-91fa-14523f6c099a",
|
252 | 232 | "metadata": {
|
253 | 233 | "tags": []
|
|
259 | 239 | "Dim.named(\"x\") * Dim(100)"
|
260 | 240 | ]
|
261 | 241 | },
|
262 |
| - "execution_count": 6, |
| 242 | + "execution_count": 5, |
263 | 243 | "metadata": {},
|
264 | 244 | "output_type": "execute_result"
|
265 | 245 | }
|
266 | 246 | ],
|
267 | 247 | "source": [
|
268 |
| - "egraph.run(10)\n", |
269 |
| - "egraph.extract(res)" |
270 |
| - ] |
271 |
| - }, |
272 |
| - { |
273 |
| - "cell_type": "markdown", |
274 |
| - "id": "bd08d366-0ebe-4e43-b219-b57fbb13534d", |
275 |
| - "metadata": {}, |
276 |
| - "source": [ |
277 |
| - "We can also extract a number of variants to see all the equivalent expresions, ordered by their cost:" |
278 |
| - ] |
279 |
| - }, |
280 |
| - { |
281 |
| - "cell_type": "code", |
282 |
| - "execution_count": 7, |
283 |
| - "id": "17e39076-0257-4bcd-b1c4-04eb5a79791e", |
284 |
| - "metadata": { |
285 |
| - "tags": [] |
286 |
| - }, |
287 |
| - "outputs": [ |
288 |
| - { |
289 |
| - "data": { |
290 |
| - "text/plain": [ |
291 |
| - "[Dim(100) * Dim.named(\"x\"),\n", |
292 |
| - " Dim.named(\"x\") * Dim(100),\n", |
293 |
| - " Dim(10) * (Dim.named(\"x\") * Dim(10)),\n", |
294 |
| - " (Dim.named(\"x\") * Dim(10)) * Dim(10)]" |
295 |
| - ] |
296 |
| - }, |
297 |
| - "execution_count": 7, |
298 |
| - "metadata": {}, |
299 |
| - "output_type": "execute_result" |
300 |
| - } |
301 |
| - ], |
302 |
| - "source": [ |
303 |
| - "egraph.extract_multiple(res, 10)" |
| 248 | + "egraph.simplify(res, 10)" |
304 | 249 | ]
|
305 | 250 | },
|
306 | 251 | {
|
|
316 | 261 | },
|
317 | 262 | {
|
318 | 263 | "cell_type": "code",
|
319 |
| - "execution_count": 8, |
| 264 | + "execution_count": 9, |
320 | 265 | "id": "c5b96cfb",
|
321 | 266 | "metadata": {},
|
322 | 267 | "outputs": [],
|
|
378 | 323 | },
|
379 | 324 | {
|
380 | 325 | "cell_type": "code",
|
381 |
| - "execution_count": 9, |
| 326 | + "execution_count": 10, |
382 | 327 | "id": "cb2b4fb8",
|
383 | 328 | "metadata": {},
|
384 | 329 | "outputs": [],
|
|
409 | 354 | },
|
410 | 355 | {
|
411 | 356 | "cell_type": "code",
|
412 |
| - "execution_count": 10, |
| 357 | + "execution_count": 13, |
413 | 358 | "id": "8d18be2d",
|
414 | 359 | "metadata": {},
|
415 | 360 | "outputs": [
|
416 | 361 | {
|
417 |
| - "data": { |
418 |
| - "text/plain": [ |
419 |
| - "(Dim.named(\"x\"), Dim.named(\"y\"))" |
420 |
| - ] |
421 |
| - }, |
422 |
| - "execution_count": 10, |
423 |
| - "metadata": {}, |
424 |
| - "output_type": "execute_result" |
| 362 | + "name": "stdout", |
| 363 | + "output_type": "stream", |
| 364 | + "text": [ |
| 365 | + "Dim.named(\"y\")\n", |
| 366 | + "Dim.named(\"x\")\n" |
| 367 | + ] |
425 | 368 | }
|
426 | 369 | ],
|
427 | 370 | "source": [
|
428 | 371 | "# If we multiply two identity matrices, we should be able to get the number of columns of the result\n",
|
429 | 372 | "x = Matrix.identity(Dim.named(\"x\"))\n",
|
430 | 373 | "y = Matrix.identity(Dim.named(\"y\"))\n",
|
431 | 374 | "x_mult_y = x @ y\n",
|
432 |
| - "x_mult_y_ncols = x_mult_y.ncols()\n", |
433 |
| - "x_mult_y_nrows = x_mult_y.nrows()\n", |
434 |
| - "\n", |
435 |
| - "egraph.register(let(\"\", x_mult_y_ncols), let(\"\", x_mult_y_nrows))\n", |
436 |
| - "\n", |
437 |
| - "egraph.run(10)\n", |
438 |
| - "egraph.extract(x_mult_y_nrows), egraph.extract(x_mult_y_ncols)" |
| 375 | + "print(egraph.simplify(x_mult_y.ncols(), 10))\n", |
| 376 | + "print(egraph.simplify(x_mult_y.nrows(), 10))" |
439 | 377 | ]
|
440 | 378 | },
|
441 | 379 | {
|
|
451 | 389 | },
|
452 | 390 | {
|
453 | 391 | "cell_type": "code",
|
454 |
| - "execution_count": 11, |
| 392 | + "execution_count": 14, |
455 | 393 | "id": "18a91684",
|
456 | 394 | "metadata": {},
|
457 | 395 | "outputs": [],
|
|
488 | 426 | },
|
489 | 427 | {
|
490 | 428 | "cell_type": "code",
|
491 |
| - "execution_count": 12, |
| 429 | + "execution_count": 15, |
492 | 430 | "id": "303ce7f3",
|
493 | 431 | "metadata": {},
|
494 | 432 | "outputs": [],
|
|
522 | 460 | },
|
523 | 461 | {
|
524 | 462 | "cell_type": "code",
|
525 |
| - "execution_count": 13, |
| 463 | + "execution_count": 16, |
526 | 464 | "id": "bb50ade6",
|
527 | 465 | "metadata": {},
|
528 | 466 | "outputs": [
|
529 | 467 | {
|
530 | 468 | "data": {
|
531 | 469 | "text/plain": [
|
532 |
| - "[kron(Matrix.identity(Dim.named(\"n\")), Matrix.named(\"B\")) @ kron(Matrix.named(\"A\"), Matrix.identity(Dim.named(\"m\"))),\n", |
533 |
| - " kron(Matrix.named(\"A\"), Matrix.named(\"B\"))]" |
| 470 | + "kron(Matrix.named(\"A\"), Matrix.named(\"B\"))" |
534 | 471 | ]
|
535 | 472 | },
|
536 |
| - "execution_count": 13, |
| 473 | + "execution_count": 16, |
537 | 474 | "metadata": {},
|
538 | 475 | "output_type": "execute_result"
|
539 | 476 | }
|
|
554 | 491 | ")\n",
|
555 | 492 | "# Create an example which should equal the kronecker product of A and B\n",
|
556 | 493 | "ex1 = kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m))\n",
|
557 |
| - "egraph.register(let(\"\", ex1))\n", |
558 |
| - "\n", |
559 |
| - "egraph.run(20)\n", |
560 |
| - "# Verify it matches the expected result\n", |
561 |
| - "egraph.check(eq(ex1).to(kron(A, B)))\n", |
562 |
| - "egraph.extract_multiple(ex1, 10)" |
| 494 | + "egraph.simplify(ex1, 20)" |
563 | 495 | ]
|
564 | 496 | },
|
565 | 497 | {
|
|
573 | 505 | },
|
574 | 506 | {
|
575 | 507 | "cell_type": "code",
|
576 |
| - "execution_count": 14, |
| 508 | + "execution_count": 17, |
577 | 509 | "id": "d8dea199",
|
578 | 510 | "metadata": {},
|
579 | 511 | "outputs": [
|
580 | 512 | {
|
581 | 513 | "data": {
|
582 | 514 | "text/plain": [
|
583 |
| - "[(kron(Matrix.identity(Dim.named(\"p\")), Matrix.named(\"C\")) @ kron(Matrix.identity(Dim.named(\"n\")), Matrix.identity(Dim.named(\"m\")))) @ kron(\n", |
584 |
| - " Matrix.named(\"A\"), Matrix.identity(Dim.named(\"m\"))\n", |
585 |
| - " ),\n", |
586 |
| - " kron(Matrix.identity(Dim.named(\"p\")), Matrix.named(\"C\")) @ kron(Matrix.named(\"A\"), Matrix.identity(Dim.named(\"m\")))]" |
| 515 | + "kron(Matrix.identity(Dim.named(\"p\")), Matrix.named(\"C\")) @ kron(Matrix.named(\"A\"), Matrix.identity(Dim.named(\"m\")))" |
587 | 516 | ]
|
588 | 517 | },
|
589 |
| - "execution_count": 14, |
| 518 | + "execution_count": 17, |
590 | 519 | "metadata": {},
|
591 | 520 | "output_type": "execute_result"
|
592 | 521 | }
|
593 | 522 | ],
|
594 | 523 | "source": [
|
595 | 524 | "ex2 = kron(Matrix.identity(p), C) @ kron(A, Matrix.identity(m))\n",
|
596 |
| - "egraph.register(let(\"\", ex2))\n", |
597 |
| - "\n", |
598 |
| - "egraph.run(10)\n", |
599 |
| - "# Verify it is not simplified\n", |
600 |
| - "egraph.check(ex2 != kron(A, C))\n", |
601 |
| - "egraph.extract_multiple(ex2, 10)" |
| 525 | + "egraph.simplify(ex2, 20)" |
602 | 526 | ]
|
603 | 527 | },
|
604 | 528 | {
|
|
0 commit comments