From 14c8c0d245fa526cbd2d5e6766ad959cd062181f Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 10 Aug 2023 09:47:32 -0400 Subject: [PATCH 1/2] Make conversions transitive and make getitem more comprehensive --- docs/changelog.md | 1 + docs/reference/egglog-translation.md | 2 + docs/tutorials/array-api.ipynb | 226 +-------------------------- python/egglog/exp/array_api.py | 109 +++++++++++-- python/egglog/runtime.py | 36 ++++- python/tests/test_convert.py | 78 +++++++++ 6 files changed, 219 insertions(+), 233 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index c75f2199..b1b173f0 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -21,6 +21,7 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea - Upgraded `egg-smol` dependency ([changes](https://github.com/saulshanabrook/egg-smol/compare/353c4387640019bd2066991ee0488dc6d5c54168...2ac80cb1162c61baef295d8e6d00351bfe84883f)) - Add support for functions which mutates their args, like `__setitem__` [#35](https://github.com/metadsl/egglog-python/pull/35) +- Makes conversions transitive ## 0.5.1 (2023-07-18) diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md index a76a4d8e..d83df385 100644 --- a/docs/reference/egglog-translation.md +++ b/docs/reference/egglog-translation.md @@ -255,6 +255,8 @@ Math(2) + 30 + "x" Math(2) + Math(i64(30)) + Math.var(String("x")) ``` +Regstering a conversion from A to B will also register all transitively reachable conversions from A to B. + ### Declarations In egglog, the `(declare ...)` command is syntactic sugar for a nullary function. In Python, these can be declare either as class variables or with the toplevel `egraph.constant` function: diff --git a/docs/tutorials/array-api.ipynb b/docs/tutorials/array-api.ipynb index 9c97eea7..4f0a0c6f 100644 --- a/docs/tutorials/array-api.ipynb +++ b/docs/tutorials/array-api.ipynb @@ -66,11 +66,11 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "import torch\n", + "# import torch\n", "\n", "# fit(torch.asarray(X), torch.asarray(y))" ] @@ -84,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -94,222 +94,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "DType.float64 == NDArray.var(\"X\").dtype\n", - " -> DType.float64 == NDArray.var(\"X\").dtype\n", - " -> TRUE\n", - "asarray(NDArray.var(\"X\")).ndim == Int(0)\n", - " -> NDArray.var(\"X\").ndim == Int(0)\n", - " -> FALSE\n", - "asarray(NDArray.var(\"X\")).ndim == Int(1)\n", - " -> NDArray.var(\"X\").ndim == Int(1)\n", - " -> FALSE\n", - "asarray(NDArray.var(\"X\")).ndim >= Int(3)\n", - " -> NDArray.var(\"X\").ndim >= Int(3)\n", - " -> FALSE\n", - "asarray(asarray(NDArray.var(\"X\"))).dtype == DType.object\n", - " -> NDArray.var(\"X\").dtype == DType.object\n", - " -> FALSE\n", - "isdtype(asarray(asarray(NDArray.var(\"X\"))).dtype, (IsDtypeKind.string(\"real floating\") | (IsDtypeKind.string(\"complex floating\") | IsDtypeKind.NULL)))\n", - " -> isdtype(NDArray.var(\"X\").dtype, (IsDtypeKind.string(\"real floating\") | IsDtypeKind.string(\"complex floating\")))\n", - " -> TRUE\n", - "isfinite(sum(asarray(asarray(NDArray.var(\"X\"))))).to_bool()\n", - " -> isfinite(sum(NDArray.var(\"X\"))).to_bool()\n", - " -> TRUE\n", - "asarray(NDArray.var(\"X\")).shape.length()\n", - " -> NDArray.var(\"X\").ndim\n", - " -> Int(2)\n", - "asarray(NDArray.var(\"X\")).shape[Int(0)] < Int(2)\n", - " -> NDArray.var(\"X\").shape[Int(0)] < Int(2)\n", - " -> FALSE\n", - "asarray(NDArray.var(\"X\")).ndim == Int(2)\n", - " -> NDArray.var(\"X\").ndim == Int(2)\n", - " -> TRUE\n", - "asarray(NDArray.var(\"X\")).shape[Int(1)] < Int(1)\n", - " -> NDArray.var(\"X\").shape[Int(1)] < Int(1)\n", - " -> FALSE\n", - "asarray(NDArray.var(\"y\")).ndim >= Int(3)\n", - " -> NDArray.var(\"y\").ndim >= Int(3)\n", - " -> FALSE\n", - "asarray(NDArray.var(\"y\")).ndim == Int(2)\n", - " -> NDArray.var(\"y\").ndim == Int(2)\n", - " -> FALSE\n", - "asarray(NDArray.var(\"y\")).shape.length()\n", - " -> NDArray.var(\"y\").ndim\n", - " -> Int(1)\n", - "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).dtype == DType.object\n", - " -> NDArray.var(\"y\").dtype == DType.object\n", - " -> FALSE\n", - "isdtype(\n", - " asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).dtype,\n", - " (IsDtypeKind.string(\"real floating\") | (IsDtypeKind.string(\"complex floating\") | IsDtypeKind.NULL)),\n", - ")\n", - " -> isdtype(NDArray.var(\"y\").dtype, (IsDtypeKind.string(\"real floating\") | IsDtypeKind.string(\"complex floating\")))\n", - " -> FALSE\n", - "asarray(NDArray.var(\"X\")).shape.length()\n", - " -> NDArray.var(\"X\").ndim\n", - " -> Int(2)\n", - "asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))).shape.length()\n", - " -> Int(1)\n", - " -> Int(1)\n", - "asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))).shape[Int(0)] < asarray(NDArray.var(\"X\")).shape[Int(0)]\n", - " -> NDArray.var(\"y\").size < NDArray.var(\"X\").shape[Int(0)]\n", - " -> FALSE\n", - "asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))).shape[Int(0)] > asarray(NDArray.var(\"X\")).shape[Int(0)]\n", - " -> NDArray.var(\"y\").size > NDArray.var(\"X\").shape[Int(0)]\n", - " -> FALSE\n", - "asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))).shape[Int(0)] == asarray(NDArray.var(\"X\")).shape[Int(0)]\n", - " -> NDArray.var(\"y\").size == NDArray.var(\"X\").shape[Int(0)]\n", - " -> TRUE\n", - "asarray(NDArray.var(\"X\")).shape.length()\n", - " -> NDArray.var(\"X\").ndim\n", - " -> Int(2)\n", - "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).ndim == Int(2)\n", - " -> FALSE\n", - " -> FALSE\n", - "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).ndim == Int(1)\n", - " -> TRUE\n", - " -> TRUE\n", - "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).shape.length()\n", - " -> Int(1)\n", - " -> Int(1)\n", - "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).shape[Int(0)] == Int(0)\n", - " -> NDArray.var(\"y\").size == Int(0)\n", - " -> FALSE\n", - "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).dtype == DType.object\n", - " -> NDArray.var(\"y\").dtype == DType.object\n", - " -> FALSE\n", - "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).ndim == Int(2)\n", - " -> FALSE\n", - " -> FALSE\n", - "isdtype(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).dtype, IsDtypeKind.string(\"real floating\"))\n", - " -> isdtype(NDArray.var(\"y\").dtype, IsDtypeKind.string(\"real floating\"))\n", - " -> FALSE\n", - "unique_values(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))).shape[Int(0)] > Int(2)\n", - " -> unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)] > Int(2)\n", - " -> TRUE\n", - "asarray(NDArray.var(\"X\")).shape.length()\n", - " -> NDArray.var(\"X\").ndim\n", - " -> Int(2)\n", - "asarray(NDArray.var(\"X\")).shape[Int(0)] == unique_values(\n", - " concat((TupleNDArray(unique_values(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))))) + TupleNDArray.EMPTY))\n", - ").shape[Int(0)]\n", - " -> NDArray.var(\"X\").shape[Int(0)] == unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)]\n", - " -> FALSE\n", - "unique_counts(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).length()\n", - " -> Int(2)\n", - " -> Int(2)\n", - "asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))).shape[Int(0)]\n", - " -> NDArray.var(\"y\").size\n", - " -> Int(150)\n", - "any(\n", - " (\n", - " (\n", - " astype(unique_counts(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(1)], asarray(NDArray.var(\"X\")).dtype)\n", - " / NDArray.scalar_float(Float(150.0))\n", - " )\n", - " < NDArray.scalar_int(Int(0))\n", - " )\n", - ").to_bool()\n", - " -> FALSE\n", - " -> FALSE\n", - "(\n", - " abs(\n", - " (\n", - " sum(\n", - " (\n", - " astype(unique_counts(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(1)], asarray(NDArray.var(\"X\")).dtype)\n", - " / NDArray.scalar_float(Float(150.0))\n", - " )\n", - " )\n", - " - NDArray.scalar_float(Float(1.0))\n", - " )\n", - " )\n", - " > NDArray.scalar_float(Float(1e-05))\n", - ").to_bool()\n", - " -> (\n", - " abs(\n", - " (\n", - " (astype(NDArray.scalar_int(reshape(NDArray.var(\"y\"), TupleInt(Int(-1))).size), NDArray.var(\"X\").dtype) / NDArray.scalar_float(Float(150.0)))\n", - " - NDArray.scalar_float(Float(1.0))\n", - " )\n", - " )\n", - " > NDArray.scalar_float(Float(1e-05))\n", - ").to_bool()\n", - " -> FALSE\n", - "asarray(NDArray.var(\"X\")).shape[Int(1)] < (\n", - " unique_values(concat((TupleNDArray(unique_values(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))))) + TupleNDArray.EMPTY))).shape[\n", - " Int(0)\n", - " ]\n", - " - Int(1)\n", - ")\n", - " -> NDArray.var(\"X\").shape[Int(1)] < (unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)] - Int(1))\n", - " -> FALSE\n", - "(\n", - " unique_values(concat((TupleNDArray(unique_values(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))))) + TupleNDArray.EMPTY))).shape[\n", - " Int(0)\n", - " ]\n", - " - Int(1)\n", - ") < Int(2)\n", - " -> (unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)] - Int(1)) < Int(2)\n", - " -> FALSE\n", - "asarray(NDArray.var(\"X\")).shape.length()\n", - " -> NDArray.var(\"X\").ndim\n", - " -> Int(2)\n", - "unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).length()\n", - " -> Int(2)\n", - " -> Int(2)\n", - "unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(0)].shape[Int(0)]\n", - "unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(0)].shape[Int(0)]\n", - " -> unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)]\n", - " -> Int(3)\n" - ] - }, - { - "ename": "TypeError", - "evalue": "Cannot convert slice(None, None, None) () to TypeRefWithVars(name='Int', args=())", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:406\u001b[0m, in \u001b[0;36m_special_method\u001b[0;34m(self, __name, *args)\u001b[0m\n\u001b[1;32m 405\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 406\u001b[0m method \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__egg_decls__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_class_decl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__egg_typed_expr__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpreserved_methods\u001b[49m\u001b[43m[\u001b[49m\u001b[43m__name\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 407\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n", - "\u001b[0;31mKeyError\u001b[0m: '__setitem__'", - "\nDuring handling of the above exception, another exception occurred:\n", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:108\u001b[0m, in \u001b[0;36m_resolve_literal\u001b[0;34m(tp, arg)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 108\u001b[0m fn \u001b[38;5;241m=\u001b[39m \u001b[43mCONVERSIONS\u001b[49m\u001b[43m[\u001b[49m\u001b[43m(\u001b[49m\u001b[43marg_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtp_just\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n", - "\u001b[0;31mKeyError\u001b[0m: (, JustTypeRef(name='Int', args=()))", - "\nDuring handling of the above exception, another exception occurred:\n", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[7], line 21\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 8\u001b[0m egraph\u001b[38;5;241m.\u001b[39mregister(\n\u001b[1;32m 9\u001b[0m rewrite(X_arr\u001b[38;5;241m.\u001b[39mdtype, runtime_ruleset)\u001b[38;5;241m.\u001b[39mto(convert(X\u001b[38;5;241m.\u001b[39mdtype, DType)),\n\u001b[1;32m 10\u001b[0m rewrite(y_arr\u001b[38;5;241m.\u001b[39mdtype, runtime_ruleset)\u001b[38;5;241m.\u001b[39mto(convert(y\u001b[38;5;241m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m rewrite(unique_values(y_arr)\u001b[38;5;241m.\u001b[39mshape)\u001b[38;5;241m.\u001b[39mto(TupleInt(Int(\u001b[38;5;241m3\u001b[39m))),\n\u001b[1;32m 18\u001b[0m )\n\u001b[0;32m---> 21\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_arr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_arr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[38;5;66;03m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 24\u001b[0m \n\u001b[1;32m 25\u001b[0m \u001b[38;5;66;03m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;66;03m# y_arr = NDArray(y_obj)\u001b[39;00m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;66;03m# TODO: Make index type be a list. Each item in the list can be a slice, an int,\u001b[39;00m\n", - "Cell \u001b[0;32mIn[1], line 15\u001b[0m, in \u001b[0;36mfit\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(array_api_dispatch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[1;32m 14\u001b[0m lda \u001b[38;5;241m=\u001b[39m LinearDiscriminantAnalysis(n_components\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m X_r2 \u001b[38;5;241m=\u001b[39m \u001b[43mlda\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mtransform(X)\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X_r2\n\u001b[1;32m 18\u001b[0m target_names \u001b[38;5;241m=\u001b[39m iris\u001b[38;5;241m.\u001b[39mtarget_names\n", - "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/base.py:1151\u001b[0m, in \u001b[0;36m_fit_context..decorator..wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1144\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[1;32m 1146\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m 1147\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m 1148\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 1149\u001b[0m )\n\u001b[1;32m 1150\u001b[0m ):\n\u001b[0;32m-> 1151\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:629\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcovariance_estimator \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 624\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 625\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcovariance estimator \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 626\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis not supported \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwith svd solver. Try another solver\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 628\u001b[0m )\n\u001b[0;32m--> 629\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_solve_svd\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 630\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msolver \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlsqr\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_solve_lstsq(\n\u001b[1;32m 632\u001b[0m X,\n\u001b[1;32m 633\u001b[0m y,\n\u001b[1;32m 634\u001b[0m shrinkage\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshrinkage,\n\u001b[1;32m 635\u001b[0m covariance_estimator\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcovariance_estimator,\n\u001b[1;32m 636\u001b[0m )\n", - "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:501\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis._solve_svd\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 498\u001b[0m n_samples, n_features \u001b[38;5;241m=\u001b[39m X\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 499\u001b[0m n_classes \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclasses_\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m--> 501\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmeans_ \u001b[38;5;241m=\u001b[39m \u001b[43m_class_means\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 502\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstore_covariance:\n\u001b[1;32m 503\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcovariance_ \u001b[38;5;241m=\u001b[39m _class_cov(X, y, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpriors_)\n", - "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:121\u001b[0m, in \u001b[0;36m_class_means\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28mprint\u001b[39m(classes\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m 120\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(classes\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]):\n\u001b[0;32m--> 121\u001b[0m \u001b[43mmeans\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;241m=\u001b[39m xp\u001b[38;5;241m.\u001b[39mmean(X[y \u001b[38;5;241m==\u001b[39m i], axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 122\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 123\u001b[0m \u001b[38;5;66;03m# TODO: Explore the choice of using bincount + add.at as it seems sub optimal\u001b[39;00m\n\u001b[1;32m 124\u001b[0m \u001b[38;5;66;03m# from a performance-wise\u001b[39;00m\n\u001b[1;32m 125\u001b[0m cnt \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mbincount(y)\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:408\u001b[0m, in \u001b[0;36m_special_method\u001b[0;34m(self, __name, *args)\u001b[0m\n\u001b[1;32m 406\u001b[0m method \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__egg_decls__\u001b[38;5;241m.\u001b[39mget_class_decl(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__egg_typed_expr__\u001b[38;5;241m.\u001b[39mtp\u001b[38;5;241m.\u001b[39mname)\u001b[38;5;241m.\u001b[39mpreserved_methods[__name]\n\u001b[1;32m 407\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n\u001b[0;32m--> 408\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mRuntimeMethod\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m__name\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 409\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 410\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m method(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs)\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:331\u001b[0m, in \u001b[0;36mRuntimeMethod.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 329\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: \u001b[38;5;28mobject\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[RuntimeExpr]:\n\u001b[1;32m 330\u001b[0m args \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__egg_self__, \u001b[38;5;241m*\u001b[39margs)\n\u001b[0;32m--> 331\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__egg_self__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__egg_decls__\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__egg_callable_ref__\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__egg_fn_decl__\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:247\u001b[0m, in \u001b[0;36m_call\u001b[0;34m(decls, callable_ref, fn_decl, args, kwargs, bound_params)\u001b[0m\n\u001b[1;32m 245\u001b[0m upcasted_args: \u001b[38;5;28mlist\u001b[39m[RuntimeExpr]\n\u001b[1;32m 246\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fn_decl \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 247\u001b[0m upcasted_args \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 248\u001b[0m _resolve_literal(tp, arg) \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m 249\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m arg, tp \u001b[38;5;129;01min\u001b[39;00m zip_longest(args, fn_decl\u001b[38;5;241m.\u001b[39marg_types, fillvalue\u001b[38;5;241m=\u001b[39mfn_decl\u001b[38;5;241m.\u001b[39mvar_arg_type)\n\u001b[1;32m 250\u001b[0m ]\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m upcasted_args \u001b[38;5;241m=\u001b[39m cast(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlist[RuntimeExpr]\u001b[39m\u001b[38;5;124m\"\u001b[39m, args)\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:248\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 245\u001b[0m upcasted_args: \u001b[38;5;28mlist\u001b[39m[RuntimeExpr]\n\u001b[1;32m 246\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fn_decl \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 247\u001b[0m upcasted_args \u001b[38;5;241m=\u001b[39m [\n\u001b[0;32m--> 248\u001b[0m \u001b[43m_resolve_literal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43marg\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m 249\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m arg, tp \u001b[38;5;129;01min\u001b[39;00m zip_longest(args, fn_decl\u001b[38;5;241m.\u001b[39marg_types, fillvalue\u001b[38;5;241m=\u001b[39mfn_decl\u001b[38;5;241m.\u001b[39mvar_arg_type)\n\u001b[1;32m 250\u001b[0m ]\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m upcasted_args \u001b[38;5;241m=\u001b[39m cast(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlist[RuntimeExpr]\u001b[39m\u001b[38;5;124m\"\u001b[39m, args)\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:111\u001b[0m, in \u001b[0;36m_resolve_literal\u001b[0;34m(tp, arg)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot convert \u001b[39m\u001b[38;5;132;01m{\u001b[39;00marg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mrepr\u001b[39m(arg_type)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtp\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 111\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43marg\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/exp/ndarray.py:303\u001b[0m, in \u001b[0;36m\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 298\u001b[0m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[1;32m 299\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mint\u001b[39m(\u001b[38;5;28mcls\u001b[39m, i: Int) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m IndexKey:\n\u001b[1;32m 300\u001b[0m \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\n\u001b[0;32m--> 303\u001b[0m converter(\u001b[38;5;28mtuple\u001b[39m, IndexKey, \u001b[38;5;28;01mlambda\u001b[39;00m x: IndexKey\u001b[38;5;241m.\u001b[39mtuple_int(\u001b[43mconvert\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mTupleInt\u001b[49m\u001b[43m)\u001b[49m))\n\u001b[1;32m 304\u001b[0m converter(\u001b[38;5;28mint\u001b[39m, IndexKey, \u001b[38;5;28;01mlambda\u001b[39;00m x: IndexKey\u001b[38;5;241m.\u001b[39mint(Int(x)))\n\u001b[1;32m 305\u001b[0m converter(Int, IndexKey, \u001b[38;5;28;01mlambda\u001b[39;00m x: IndexKey\u001b[38;5;241m.\u001b[39mint(x))\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:81\u001b[0m, in \u001b[0;36mconvert\u001b[0;34m(source, target)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;124;03mConvert a source object to a target type.\u001b[39;00m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 80\u001b[0m target_ref \u001b[38;5;241m=\u001b[39m class_to_ref(target) \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[0;32m---> 81\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m cast(V, \u001b[43m_resolve_literal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtarget_ref\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_var\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msource\u001b[49m\u001b[43m)\u001b[49m)\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:111\u001b[0m, in \u001b[0;36m_resolve_literal\u001b[0;34m(tp, arg)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot convert \u001b[39m\u001b[38;5;132;01m{\u001b[39;00marg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mrepr\u001b[39m(arg_type)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtp\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 111\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43marg\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/exp/ndarray.py:273\u001b[0m, in \u001b[0;36m\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, i: Int) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Int:\n\u001b[1;32m 270\u001b[0m \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\n\u001b[0;32m--> 273\u001b[0m converter(\u001b[38;5;28mtuple\u001b[39m, TupleInt, \u001b[38;5;28;01mlambda\u001b[39;00m x: TupleInt(convert(x[\u001b[38;5;241m0\u001b[39m], Int)) \u001b[38;5;241m+\u001b[39m \u001b[43mconvert\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mTupleInt\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mif\u001b[39;00m x \u001b[38;5;28;01melse\u001b[39;00m TupleInt\u001b[38;5;241m.\u001b[39mEMPTY)\n\u001b[1;32m 276\u001b[0m \u001b[38;5;129m@egraph\u001b[39m\u001b[38;5;241m.\u001b[39mregister\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_tuple_int\u001b[39m(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[1;32m 279\u001b[0m rewrite(ti \u001b[38;5;241m+\u001b[39m TupleInt\u001b[38;5;241m.\u001b[39mEMPTY)\u001b[38;5;241m.\u001b[39mto(ti),\n\u001b[1;32m 280\u001b[0m rewrite(TupleInt(i)\u001b[38;5;241m.\u001b[39mlength())\u001b[38;5;241m.\u001b[39mto(Int(\u001b[38;5;241m1\u001b[39m)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 285\u001b[0m rule(eq(i)\u001b[38;5;241m.\u001b[39mto((TupleInt(i2) \u001b[38;5;241m+\u001b[39m ti)[Int(k)]), k \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mthen(union(i)\u001b[38;5;241m.\u001b[39mwith_(ti[Int(k \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m)])),\n\u001b[1;32m 286\u001b[0m ]\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:81\u001b[0m, in \u001b[0;36mconvert\u001b[0;34m(source, target)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;124;03mConvert a source object to a target type.\u001b[39;00m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 80\u001b[0m target_ref \u001b[38;5;241m=\u001b[39m class_to_ref(target) \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[0;32m---> 81\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m cast(V, \u001b[43m_resolve_literal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtarget_ref\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_var\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msource\u001b[49m\u001b[43m)\u001b[49m)\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:111\u001b[0m, in \u001b[0;36m_resolve_literal\u001b[0;34m(tp, arg)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot convert \u001b[39m\u001b[38;5;132;01m{\u001b[39;00marg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mrepr\u001b[39m(arg_type)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtp\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 111\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43marg\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/exp/ndarray.py:273\u001b[0m, in \u001b[0;36m\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, i: Int) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Int:\n\u001b[1;32m 270\u001b[0m \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\n\u001b[0;32m--> 273\u001b[0m converter(\u001b[38;5;28mtuple\u001b[39m, TupleInt, \u001b[38;5;28;01mlambda\u001b[39;00m x: TupleInt(\u001b[43mconvert\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mInt\u001b[49m\u001b[43m)\u001b[49m) \u001b[38;5;241m+\u001b[39m convert(x[\u001b[38;5;241m1\u001b[39m:], TupleInt) \u001b[38;5;28;01mif\u001b[39;00m x \u001b[38;5;28;01melse\u001b[39;00m TupleInt\u001b[38;5;241m.\u001b[39mEMPTY)\n\u001b[1;32m 276\u001b[0m \u001b[38;5;129m@egraph\u001b[39m\u001b[38;5;241m.\u001b[39mregister\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_tuple_int\u001b[39m(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[1;32m 279\u001b[0m rewrite(ti \u001b[38;5;241m+\u001b[39m TupleInt\u001b[38;5;241m.\u001b[39mEMPTY)\u001b[38;5;241m.\u001b[39mto(ti),\n\u001b[1;32m 280\u001b[0m rewrite(TupleInt(i)\u001b[38;5;241m.\u001b[39mlength())\u001b[38;5;241m.\u001b[39mto(Int(\u001b[38;5;241m1\u001b[39m)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 285\u001b[0m rule(eq(i)\u001b[38;5;241m.\u001b[39mto((TupleInt(i2) \u001b[38;5;241m+\u001b[39m ti)[Int(k)]), k \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mthen(union(i)\u001b[38;5;241m.\u001b[39mwith_(ti[Int(k \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m)])),\n\u001b[1;32m 286\u001b[0m ]\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:81\u001b[0m, in \u001b[0;36mconvert\u001b[0;34m(source, target)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;124;03mConvert a source object to a target type.\u001b[39;00m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 80\u001b[0m target_ref \u001b[38;5;241m=\u001b[39m class_to_ref(target) \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[0;32m---> 81\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m cast(V, \u001b[43m_resolve_literal\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtarget_ref\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_var\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msource\u001b[49m\u001b[43m)\u001b[49m)\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:110\u001b[0m, in \u001b[0;36m_resolve_literal\u001b[0;34m(tp, arg)\u001b[0m\n\u001b[1;32m 108\u001b[0m fn \u001b[38;5;241m=\u001b[39m CONVERSIONS[(arg_type, tp_just)]\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n\u001b[0;32m--> 110\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot convert \u001b[39m\u001b[38;5;132;01m{\u001b[39;00marg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mrepr\u001b[39m(arg_type)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtp\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fn(arg)\n", - "\u001b[0;31mTypeError\u001b[0m: Cannot convert slice(None, None, None) () to TypeRefWithVars(name='Int', args=())" - ] - } - ], + "outputs": [], "source": [ "from egglog.exp.array_api import *\n", "\n", @@ -336,8 +123,7 @@ "# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\n", "\n", "# X_arr = NDArray(X_obj)\n", - "# y_arr = NDArray(y_obj)\n", - "# TODO: Make index type be a list. Each item in the list can be a slice, an int," + "# y_arr = NDArray(y_obj)" ] }, { diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index 128a741b..021b81c7 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -1,5 +1,3 @@ -# mypy: disable-error-code=empty-body - from __future__ import annotations import itertools @@ -13,6 +11,9 @@ # Pretend that exprs are numbers b/c scikit learn does isinstance checks from egglog.runtime import RuntimeExpr +# mypy: disable-error-code=empty-body + + numbers.Integral.register(RuntimeExpr) egraph = EGraph() @@ -111,7 +112,6 @@ def isdtype(dtype: DType, kind: IsDtypeKind) -> Bool: ... -converter(np.dtype, IsDtypeKind, lambda x: IsDtypeKind.dtype(convert(x, DType))) converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x)) converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x)) converter( @@ -286,23 +286,108 @@ def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64): ] -# HANDLED_FUNCTIONS = {} +@egraph.class_ +class OptionalInt(Expr): + none: ClassVar[OptionalInt] + + @classmethod + def some(cls, value: Int) -> OptionalInt: + ... + + +converter(type(None), OptionalInt, lambda x: OptionalInt.none) +converter(Int, OptionalInt, OptionalInt.some) @egraph.class_ -class IndexKey(Expr): +class Slice(Expr): + def __init__(self, start: OptionalInt, stop: OptionalInt, step: OptionalInt) -> None: + ... + + +converter( + slice, + Slice, + lambda x: Slice(convert(x.start, OptionalInt), convert(x.stop, OptionalInt), convert(x.step, OptionalInt)), +) + + +@egraph.class_ +class MultiAxisIndexKeyItem(Expr): + ELLIPSIS: ClassVar[MultiAxisIndexKeyItem] + NONE: ClassVar[MultiAxisIndexKeyItem] + @classmethod - def tuple_int(cls, ti: TupleInt) -> IndexKey: + def int(cls, i: Int) -> MultiAxisIndexKeyItem: ... + @classmethod + def slice(cls, slice: Slice) -> MultiAxisIndexKeyItem: + ... + + +converter(type(...), MultiAxisIndexKeyItem, lambda x: MultiAxisIndexKeyItem.ELLIPSIS) +converter(type(None), MultiAxisIndexKeyItem, lambda x: MultiAxisIndexKeyItem.NONE) +converter(Int, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.int) +converter(Slice, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.slice) + + +@egraph.class_ +class MultiAxisIndexKey(Expr): + def __init__(self, item: MultiAxisIndexKeyItem) -> None: + ... + + EMPTY: ClassVar[MultiAxisIndexKey] + + def __add__(self, other: MultiAxisIndexKey) -> MultiAxisIndexKey: + ... + + +converter( + tuple, + MultiAxisIndexKey, + lambda x: MultiAxisIndexKey(convert(x[0], MultiAxisIndexKeyItem)) + convert(x[1:], MultiAxisIndexKey) + if x + else MultiAxisIndexKey.EMPTY, +) + + +@egraph.class_ +class IndexKey(Expr): + """ + A key for indexing into an array + + https://data-apis.org/array-api/2022.12/API_specification/indexing.html + + It is equivalent to the following type signature: + + Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis, None], ...], array] + """ + + ELLIPSIS: ClassVar[IndexKey] + @classmethod def int(cls, i: Int) -> IndexKey: ... + @classmethod + def slice(cls, slice: Slice) -> IndexKey: + ... + + # Disabled until we support late binding + # @classmethod + # def boolean_array(cls, b: NDArray) -> IndexKey: + # ... + + @classmethod + def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey: + ... + -converter(tuple, IndexKey, lambda x: IndexKey.tuple_int(convert(x, TupleInt))) -converter(int, IndexKey, lambda x: IndexKey.int(Int(x))) -converter(Int, IndexKey, lambda x: IndexKey.int(x)) +converter(type(...), IndexKey, lambda x: IndexKey.ELLIPSIS) +converter(Int, IndexKey, IndexKey.int) +converter(Slice, IndexKey, IndexKey.slice) +converter(MultiAxisIndexKey, IndexKey, IndexKey.multi_axis) @egraph.class_ @@ -400,8 +485,8 @@ def ndarray_index(x: NDArray) -> IndexKey: converter(NDArray, IndexKey, ndarray_index) -converter(float, NDArray, lambda x: NDArray.scalar_float(Float(x))) -converter(int, NDArray, lambda x: NDArray.scalar_int(Int(x))) +converter(Float, NDArray, NDArray.scalar_float) +converter(Int, NDArray, NDArray.scalar_int) @egraph.register @@ -478,7 +563,6 @@ def some(cls, value: Bool) -> OptionalBool: converter(type(None), OptionalBool, lambda x: OptionalBool.none) converter(Bool, OptionalBool, lambda x: OptionalBool.some(x)) -converter(bool, OptionalBool, lambda x: OptionalBool.some(convert(x, Bool))) @egraph.class_ @@ -518,6 +602,7 @@ def some(cls, value: TupleInt) -> OptionalTupleInt: converter(type(None), OptionalTupleInt, lambda x: OptionalTupleInt.none) converter(TupleInt, OptionalTupleInt, lambda x: OptionalTupleInt.some(x)) +# TODO: Don't allow ints to be converted to OptionalTupleInt, and have another type that also unions ints converter(int, OptionalTupleInt, lambda x: OptionalTupleInt.some(TupleInt(Int(x)))) diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index 82a81ada..be478140 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -70,7 +70,41 @@ def converter(from_type: Type[T], to_type: Type[V], fn: Callable[[T], V]) -> Non to_type_name = process_tp(to_type) if not isinstance(to_type_name, JustTypeRef): raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}") - CONVERSIONS[(process_tp(from_type), to_type_name)] = fn + _register_converter(process_tp(from_type), to_type_name, fn) + + +def _register_converter(a: Type | JustTypeRef, b: JustTypeRef, a_b: Callable) -> None: + """ + Registers a converter from some type to an egglog type, if not already registered. + + Also adds transitive converters, i.e. if registering A->B and there is already B->C, then A->C will be registered. + Also, if registering A->B and there is already D->A, then D->B will be registered. + """ + if a == b or (a, b) in CONVERSIONS: + return + CONVERSIONS[(a, b)] = a_b + for (c, d), c_d in list(CONVERSIONS.items()): + if b == c: + _register_converter(a, d, _ComposedConverter(a_b, c_d)) + if a == d: + _register_converter(c, b, _ComposedConverter(c_d, a_b)) + + +@dataclass +class _ComposedConverter: + """ + A converter which is composed of multiple converters. + + _ComposeConverter(a_b, b_c) is equivalent to lambda x: b_c(a_b(x)) + + We use the dataclass instead of the lambda to make it easier to debug. + """ + + a_b: Callable + b_c: Callable + + def __call__(self, x: object) -> object: + return self.b_c(self.a_b(x)) def convert(source: object, target: type[V]) -> V: diff --git a/python/tests/test_convert.py b/python/tests/test_convert.py index e5790cbb..409f005b 100644 --- a/python/tests/test_convert.py +++ b/python/tests/test_convert.py @@ -1,6 +1,17 @@ +import copy + +import egglog.runtime +import pytest from egglog import * +@pytest.fixture(autouse=True) +def reset_conversions(): + old_conversions = copy.copy(egglog.runtime.CONVERSIONS) + yield + egglog.runtime.CONVERSIONS = old_conversions + + def test_conversion_custom_metaclass(): class MyMeta(type): pass @@ -33,3 +44,70 @@ def __init__(self): converter(MyType, MyTypeExpr, lambda x: MyTypeExpr()) assert expr_parts(convert(MyType(), MyTypeExpr)) == expr_parts(MyTypeExpr()) + + +def test_conversion_transitive_forward(): + egraph = EGraph() + + class MyType: + pass + + @egraph.class_ + class MyTypeExpr(Expr): + def __init__(self): + ... + + @egraph.class_ + class MyTypeExpr2(Expr): + def __init__(self): + ... + + converter(MyType, MyTypeExpr, lambda x: MyTypeExpr()) + converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2()) + + assert expr_parts(convert(MyType(), MyTypeExpr2)) == expr_parts(MyTypeExpr2()) + + +def test_conversion_transitive_backward(): + egraph = EGraph() + + class MyType: + pass + + @egraph.class_ + class MyTypeExpr(Expr): + def __init__(self): + ... + + @egraph.class_ + class MyTypeExpr2(Expr): + def __init__(self): + ... + + converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2()) + converter(MyType, MyTypeExpr, lambda x: MyTypeExpr()) + assert expr_parts(convert(MyType(), MyTypeExpr2)) == expr_parts(MyTypeExpr2()) + + +def test_conversion_transitive_cycle(): + egraph = EGraph() + + class MyType: + pass + + @egraph.class_ + class MyTypeExpr(Expr): + def __init__(self): + ... + + @egraph.class_ + class MyTypeExpr2(Expr): + def __init__(self): + ... + + converter(MyType, MyTypeExpr, lambda x: MyTypeExpr()) + converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2()) + converter(MyTypeExpr2, MyTypeExpr, lambda x: MyTypeExpr()) + + assert expr_parts(convert(MyType(), MyTypeExpr2)) == expr_parts(MyTypeExpr2()) + assert expr_parts(convert(MyType(), MyTypeExpr)) == expr_parts(MyTypeExpr()) From 2c0a9189e75c97c40d9c3712215da8fddf3f137f Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 10 Aug 2023 09:49:54 -0400 Subject: [PATCH 2/2] Add PR # to changelog --- docs/changelog.md | 2 +- docs/tutorials/array-api.ipynb | 206 ++++++++++++++++++++++++++++++++- 2 files changed, 203 insertions(+), 5 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index b1b173f0..65ae4d5a 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -21,7 +21,7 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea - Upgraded `egg-smol` dependency ([changes](https://github.com/saulshanabrook/egg-smol/compare/353c4387640019bd2066991ee0488dc6d5c54168...2ac80cb1162c61baef295d8e6d00351bfe84883f)) - Add support for functions which mutates their args, like `__setitem__` [#35](https://github.com/metadsl/egglog-python/pull/35) -- Makes conversions transitive +- Makes conversions transitive [#38](https://github.com/metadsl/egglog-python/pull/38) ## 0.5.1 (2023-07-18) diff --git a/docs/tutorials/array-api.ipynb b/docs/tutorials/array-api.ipynb index 4f0a0c6f..cf0c517a 100644 --- a/docs/tutorials/array-api.ipynb +++ b/docs/tutorials/array-api.ipynb @@ -66,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -84,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -94,9 +94,207 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "isdtype(DType.float32, IsDtypeKind.string(\"integral\"))\n", + " -> FALSE\n", + " -> FALSE\n", + "DType.float64 == NDArray.var(\"X\").dtype\n", + " -> DType.float64 == NDArray.var(\"X\").dtype\n", + " -> TRUE\n", + "asarray(NDArray.var(\"X\")).ndim == Int(0)\n", + " -> NDArray.var(\"X\").ndim == Int(0)\n", + " -> FALSE\n", + "asarray(NDArray.var(\"X\")).ndim == Int(1)\n", + " -> NDArray.var(\"X\").ndim == Int(1)\n", + " -> FALSE\n", + "asarray(NDArray.var(\"X\")).ndim >= Int(3)\n", + " -> NDArray.var(\"X\").ndim >= Int(3)\n", + " -> FALSE\n", + "asarray(asarray(NDArray.var(\"X\"))).dtype == DType.object\n", + " -> NDArray.var(\"X\").dtype == DType.object\n", + " -> FALSE\n", + "isdtype(asarray(asarray(NDArray.var(\"X\"))).dtype, (IsDtypeKind.string(\"real floating\") | (IsDtypeKind.string(\"complex floating\") | IsDtypeKind.NULL)))\n", + " -> isdtype(NDArray.var(\"X\").dtype, (IsDtypeKind.string(\"real floating\") | IsDtypeKind.string(\"complex floating\")))\n", + " -> TRUE\n", + "isfinite(sum(asarray(asarray(NDArray.var(\"X\"))))).to_bool()\n", + " -> isfinite(sum(NDArray.var(\"X\"))).to_bool()\n", + " -> TRUE\n", + "asarray(NDArray.var(\"X\")).shape.length()\n", + " -> NDArray.var(\"X\").ndim\n", + " -> Int(2)\n", + "asarray(NDArray.var(\"X\")).shape[Int(0)] < Int(2)\n", + " -> NDArray.var(\"X\").shape[Int(0)] < Int(2)\n", + " -> FALSE\n", + "asarray(NDArray.var(\"X\")).ndim == Int(2)\n", + " -> NDArray.var(\"X\").ndim == Int(2)\n", + " -> TRUE\n", + "asarray(NDArray.var(\"X\")).shape[Int(1)] < Int(1)\n", + " -> NDArray.var(\"X\").shape[Int(1)] < Int(1)\n", + " -> FALSE\n", + "asarray(NDArray.var(\"y\")).ndim >= Int(3)\n", + " -> NDArray.var(\"y\").ndim >= Int(3)\n", + " -> FALSE\n", + "asarray(NDArray.var(\"y\")).ndim == Int(2)\n", + " -> NDArray.var(\"y\").ndim == Int(2)\n", + " -> FALSE\n", + "asarray(NDArray.var(\"y\")).shape.length()\n", + " -> NDArray.var(\"y\").ndim\n", + " -> Int(1)\n", + "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).dtype == DType.object\n", + " -> NDArray.var(\"y\").dtype == DType.object\n", + " -> FALSE\n", + "isdtype(\n", + " asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).dtype,\n", + " (IsDtypeKind.string(\"real floating\") | (IsDtypeKind.string(\"complex floating\") | IsDtypeKind.NULL)),\n", + ")\n", + " -> isdtype(NDArray.var(\"y\").dtype, (IsDtypeKind.string(\"real floating\") | IsDtypeKind.string(\"complex floating\")))\n", + " -> FALSE\n", + "asarray(NDArray.var(\"X\")).shape.length()\n", + " -> NDArray.var(\"X\").ndim\n", + " -> Int(2)\n", + "asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))).shape.length()\n", + " -> Int(1)\n", + " -> Int(1)\n", + "asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))).shape[Int(0)] < asarray(NDArray.var(\"X\")).shape[Int(0)]\n", + " -> NDArray.var(\"y\").size < NDArray.var(\"X\").shape[Int(0)]\n", + " -> FALSE\n", + "asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))).shape[Int(0)] > asarray(NDArray.var(\"X\")).shape[Int(0)]\n", + " -> NDArray.var(\"y\").size > NDArray.var(\"X\").shape[Int(0)]\n", + " -> FALSE\n", + "asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))).shape[Int(0)] == asarray(NDArray.var(\"X\")).shape[Int(0)]\n", + " -> NDArray.var(\"y\").size == NDArray.var(\"X\").shape[Int(0)]\n", + " -> TRUE\n", + "asarray(NDArray.var(\"X\")).shape.length()\n", + " -> NDArray.var(\"X\").ndim\n", + " -> Int(2)\n", + "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).ndim == Int(2)\n", + " -> FALSE\n", + " -> FALSE\n", + "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).ndim == Int(1)\n", + " -> TRUE\n", + " -> TRUE\n", + "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).shape.length()\n", + " -> Int(1)\n", + " -> Int(1)\n", + "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).shape[Int(0)] == Int(0)\n", + " -> NDArray.var(\"y\").size == Int(0)\n", + " -> FALSE\n", + "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).dtype == DType.object\n", + " -> NDArray.var(\"y\").dtype == DType.object\n", + " -> FALSE\n", + "asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).ndim == Int(2)\n", + " -> FALSE\n", + " -> FALSE\n", + "isdtype(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).dtype, IsDtypeKind.string(\"real floating\"))\n", + " -> isdtype(NDArray.var(\"y\").dtype, IsDtypeKind.string(\"real floating\"))\n", + " -> FALSE\n", + "unique_values(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))).shape[Int(0)] > Int(2)\n", + " -> unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)] > Int(2)\n", + " -> TRUE\n", + "asarray(NDArray.var(\"X\")).shape.length()\n", + " -> NDArray.var(\"X\").ndim\n", + " -> Int(2)\n", + "asarray(NDArray.var(\"X\")).shape[Int(0)] == unique_values(\n", + " concat((TupleNDArray(unique_values(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))))) + TupleNDArray.EMPTY))\n", + ").shape[Int(0)]\n", + " -> NDArray.var(\"X\").shape[Int(0)] == unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)]\n", + " -> FALSE\n", + "unique_counts(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).length()\n", + " -> Int(2)\n", + " -> Int(2)\n", + "asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))).shape[Int(0)]\n", + " -> NDArray.var(\"y\").size\n", + " -> Int(150)\n", + "any(\n", + " (\n", + " (\n", + " astype(unique_counts(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(1)], asarray(NDArray.var(\"X\")).dtype)\n", + " / NDArray.scalar_float(Float(150.0))\n", + " )\n", + " < NDArray.scalar_int(Int(0))\n", + " )\n", + ").to_bool()\n", + " -> FALSE\n", + " -> FALSE\n", + "(\n", + " abs(\n", + " (\n", + " sum(\n", + " (\n", + " astype(unique_counts(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(1)], asarray(NDArray.var(\"X\")).dtype)\n", + " / NDArray.scalar_float(Float(150.0))\n", + " )\n", + " )\n", + " - NDArray.scalar_float(Float(1.0))\n", + " )\n", + " )\n", + " > NDArray.scalar_float(Float(1e-05))\n", + ").to_bool()\n", + " -> (\n", + " abs(\n", + " (\n", + " (astype(NDArray.scalar_int(reshape(NDArray.var(\"y\"), TupleInt(Int(-1))).size), NDArray.var(\"X\").dtype) / NDArray.scalar_float(Float(150.0)))\n", + " - NDArray.scalar_float(Float(1.0))\n", + " )\n", + " )\n", + " > NDArray.scalar_float(Float(1e-05))\n", + ").to_bool()\n", + " -> FALSE\n", + "asarray(NDArray.var(\"X\")).shape[Int(1)] < (\n", + " unique_values(concat((TupleNDArray(unique_values(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))))) + TupleNDArray.EMPTY))).shape[\n", + " Int(0)\n", + " ]\n", + " - Int(1)\n", + ")\n", + " -> NDArray.var(\"X\").shape[Int(1)] < (unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)] - Int(1))\n", + " -> FALSE\n", + "(\n", + " unique_values(concat((TupleNDArray(unique_values(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))))) + TupleNDArray.EMPTY))).shape[\n", + " Int(0)\n", + " ]\n", + " - Int(1)\n", + ") < Int(2)\n", + " -> (unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)] - Int(1)) < Int(2)\n", + " -> FALSE\n", + "asarray(NDArray.var(\"X\")).shape.length()\n", + " -> NDArray.var(\"X\").ndim\n", + " -> Int(2)\n", + "unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).length()\n", + " -> Int(2)\n", + " -> Int(2)\n", + "unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(0)].shape[Int(0)]\n", + "unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(0)].shape[Int(0)]\n", + " -> unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)]\n", + " -> Int(3)\n" + ] + }, + { + "ename": "TypeError", + "evalue": "NDArray has no method __iter__", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:452\u001b[0m, in \u001b[0;36m_preserved_method\u001b[0;34m(self, __name)\u001b[0m\n\u001b[1;32m 451\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 452\u001b[0m method \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__egg_decls__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_class_decl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__egg_typed_expr__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpreserved_methods\u001b[49m\u001b[43m[\u001b[49m\u001b[43m__name\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 453\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n", + "\u001b[0;31mKeyError\u001b[0m: '__iter__'", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[4], line 21\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 8\u001b[0m egraph\u001b[38;5;241m.\u001b[39mregister(\n\u001b[1;32m 9\u001b[0m rewrite(X_arr\u001b[38;5;241m.\u001b[39mdtype, runtime_ruleset)\u001b[38;5;241m.\u001b[39mto(convert(X\u001b[38;5;241m.\u001b[39mdtype, DType)),\n\u001b[1;32m 10\u001b[0m rewrite(y_arr\u001b[38;5;241m.\u001b[39mdtype, runtime_ruleset)\u001b[38;5;241m.\u001b[39mto(convert(y\u001b[38;5;241m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m rewrite(unique_values(y_arr)\u001b[38;5;241m.\u001b[39mshape)\u001b[38;5;241m.\u001b[39mto(TupleInt(Int(\u001b[38;5;241m3\u001b[39m))),\n\u001b[1;32m 18\u001b[0m )\n\u001b[0;32m---> 21\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_arr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_arr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[38;5;66;03m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 24\u001b[0m \n\u001b[1;32m 25\u001b[0m \u001b[38;5;66;03m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;66;03m# y_arr = NDArray(y_obj)\u001b[39;00m\n", + "Cell \u001b[0;32mIn[1], line 15\u001b[0m, in \u001b[0;36mfit\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(array_api_dispatch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[1;32m 14\u001b[0m lda \u001b[38;5;241m=\u001b[39m LinearDiscriminantAnalysis(n_components\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m X_r2 \u001b[38;5;241m=\u001b[39m \u001b[43mlda\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mtransform(X)\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X_r2\n\u001b[1;32m 18\u001b[0m target_names \u001b[38;5;241m=\u001b[39m iris\u001b[38;5;241m.\u001b[39mtarget_names\n", + "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/base.py:1151\u001b[0m, in \u001b[0;36m_fit_context..decorator..wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1144\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[1;32m 1146\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m 1147\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m 1148\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 1149\u001b[0m )\n\u001b[1;32m 1150\u001b[0m ):\n\u001b[0;32m-> 1151\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:629\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcovariance_estimator \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 624\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 625\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcovariance estimator \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 626\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis not supported \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwith svd solver. Try another solver\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 628\u001b[0m )\n\u001b[0;32m--> 629\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_solve_svd\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 630\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msolver \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlsqr\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_solve_lstsq(\n\u001b[1;32m 632\u001b[0m X,\n\u001b[1;32m 633\u001b[0m y,\n\u001b[1;32m 634\u001b[0m shrinkage\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshrinkage,\n\u001b[1;32m 635\u001b[0m covariance_estimator\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcovariance_estimator,\n\u001b[1;32m 636\u001b[0m )\n", + "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:506\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis._solve_svd\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 503\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcovariance_ \u001b[38;5;241m=\u001b[39m _class_cov(X, y, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpriors_)\n\u001b[1;32m 505\u001b[0m Xc \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m--> 506\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m idx, group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclasses_\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 507\u001b[0m Xg \u001b[38;5;241m=\u001b[39m X[y \u001b[38;5;241m==\u001b[39m group]\n\u001b[1;32m 508\u001b[0m Xc\u001b[38;5;241m.\u001b[39mappend(Xg \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmeans_[idx, :])\n", + "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:454\u001b[0m, in \u001b[0;36m_preserved_method\u001b[0;34m(self, __name)\u001b[0m\n\u001b[1;32m 452\u001b[0m method \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__egg_decls__\u001b[38;5;241m.\u001b[39mget_class_decl(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__egg_typed_expr__\u001b[38;5;241m.\u001b[39mtp\u001b[38;5;241m.\u001b[39mname)\u001b[38;5;241m.\u001b[39mpreserved_methods[__name]\n\u001b[1;32m 453\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n\u001b[0;32m--> 454\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__egg_typed_expr__\u001b[38;5;241m.\u001b[39mtp\u001b[38;5;241m.\u001b[39mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m has no method \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m__name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 455\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m method(\u001b[38;5;28mself\u001b[39m)\n", + "\u001b[0;31mTypeError\u001b[0m: NDArray has no method __iter__" + ] + } + ], "source": [ "from egglog.exp.array_api import *\n", "\n",