Skip to content

Commit 191d026

Browse files
Add device support
1 parent e9b3500 commit 191d026

File tree

1 file changed

+83
-13
lines changed

1 file changed

+83
-13
lines changed

docs/tutorials/array-api.ipynb

Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
},
2828
{
2929
"cell_type": "code",
30-
"execution_count": 2,
30+
"execution_count": 1,
3131
"metadata": {},
3232
"outputs": [],
3333
"source": [
@@ -66,7 +66,7 @@
6666
},
6767
{
6868
"cell_type": "code",
69-
"execution_count": 3,
69+
"execution_count": 2,
7070
"metadata": {},
7171
"outputs": [],
7272
"source": [
@@ -84,7 +84,7 @@
8484
},
8585
{
8686
"cell_type": "code",
87-
"execution_count": 13,
87+
"execution_count": 3,
8888
"metadata": {},
8989
"outputs": [
9090
{
@@ -254,23 +254,27 @@
254254
" -> FALSE\n",
255255
"asarray(NDArray.var(\"X\")).shape.length()\n",
256256
" -> NDArray.var(\"X\").ndim\n",
257-
" -> Int(2)\n"
257+
" -> Int(2)\n",
258+
"unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).length()\n",
259+
" -> Int(2)\n",
260+
" -> Int(2)\n",
261+
"unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(0)].shape[Int(0)]\n"
258262
]
259263
},
260264
{
261-
"ename": "AttributeError",
262-
"evalue": "module '__main__' has no attribute 'unique_inverse'",
265+
"ename": "TypeError",
266+
"evalue": "'RuntimeExpr' object cannot be interpreted as an integer",
263267
"output_type": "error",
264268
"traceback": [
265269
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
266-
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
267-
"Cell \u001b[0;32mIn[13], line 640\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[39m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 627\u001b[0m egraph\u001b[39m.\u001b[39mregister(\n\u001b[1;32m 628\u001b[0m rewrite(X_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(X\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[1;32m 629\u001b[0m rewrite(y_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(y\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 636\u001b[0m rewrite(unique_values(y_arr)\u001b[39m.\u001b[39mshape)\u001b[39m.\u001b[39mto(TupleInt(Int(\u001b[39m3\u001b[39m))),\n\u001b[1;32m 637\u001b[0m )\n\u001b[0;32m--> 640\u001b[0m res \u001b[39m=\u001b[39m fit(X_arr, y_arr)\n\u001b[1;32m 642\u001b[0m \u001b[39m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 643\u001b[0m \n\u001b[1;32m 644\u001b[0m \u001b[39m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 645\u001b[0m \u001b[39m# y_arr = NDArray(y_obj)\u001b[39;00m\n",
268-
"Cell \u001b[0;32mIn[2], line 15\u001b[0m, in \u001b[0;36mfit\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[39mwith\u001b[39;00m config_context(array_api_dispatch\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m):\n\u001b[1;32m 14\u001b[0m lda \u001b[39m=\u001b[39m LinearDiscriminantAnalysis(n_components\u001b[39m=\u001b[39m\u001b[39m2\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m X_r2 \u001b[39m=\u001b[39m lda\u001b[39m.\u001b[39;49mfit(X, y)\u001b[39m.\u001b[39mtransform(X)\n\u001b[1;32m 16\u001b[0m \u001b[39mreturn\u001b[39;00m X_r2\n\u001b[1;32m 18\u001b[0m target_names \u001b[39m=\u001b[39m iris\u001b[39m.\u001b[39mtarget_names\n",
270+
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
271+
"Cell \u001b[0;32mIn[3], line 676\u001b[0m\n\u001b[1;32m 662\u001b[0m \u001b[39m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 663\u001b[0m egraph\u001b[39m.\u001b[39mregister(\n\u001b[1;32m 664\u001b[0m rewrite(X_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(X\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[1;32m 665\u001b[0m rewrite(y_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(y\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 672\u001b[0m rewrite(unique_values(y_arr)\u001b[39m.\u001b[39mshape)\u001b[39m.\u001b[39mto(TupleInt(Int(\u001b[39m3\u001b[39m))),\n\u001b[1;32m 673\u001b[0m )\n\u001b[0;32m--> 676\u001b[0m res \u001b[39m=\u001b[39m fit(X_arr, y_arr)\n\u001b[1;32m 678\u001b[0m \u001b[39m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 679\u001b[0m \n\u001b[1;32m 680\u001b[0m \u001b[39m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 681\u001b[0m \u001b[39m# y_arr = NDArray(y_obj)\u001b[39;00m\n",
272+
"Cell \u001b[0;32mIn[1], line 15\u001b[0m, in \u001b[0;36mfit\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[39mwith\u001b[39;00m config_context(array_api_dispatch\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m):\n\u001b[1;32m 14\u001b[0m lda \u001b[39m=\u001b[39m LinearDiscriminantAnalysis(n_components\u001b[39m=\u001b[39m\u001b[39m2\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m X_r2 \u001b[39m=\u001b[39m lda\u001b[39m.\u001b[39;49mfit(X, y)\u001b[39m.\u001b[39mtransform(X)\n\u001b[1;32m 16\u001b[0m \u001b[39mreturn\u001b[39;00m X_r2\n\u001b[1;32m 18\u001b[0m target_names \u001b[39m=\u001b[39m iris\u001b[39m.\u001b[39mtarget_names\n",
269273
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/base.py:1151\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1144\u001b[0m estimator\u001b[39m.\u001b[39m_validate_params()\n\u001b[1;32m 1146\u001b[0m \u001b[39mwith\u001b[39;00m config_context(\n\u001b[1;32m 1147\u001b[0m skip_parameter_validation\u001b[39m=\u001b[39m(\n\u001b[1;32m 1148\u001b[0m prefer_skip_nested_validation \u001b[39mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 1149\u001b[0m )\n\u001b[1;32m 1150\u001b[0m ):\n\u001b[0;32m-> 1151\u001b[0m \u001b[39mreturn\u001b[39;00m fit_method(estimator, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
270-
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:628\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 622\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_estimator \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 623\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 624\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mcovariance estimator \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 625\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mis not supported \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 626\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mwith svd solver. Try another solver\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 627\u001b[0m )\n\u001b[0;32m--> 628\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_solve_svd(X, y)\n\u001b[1;32m 629\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msolver \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mlsqr\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m 630\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_solve_lstsq(\n\u001b[1;32m 631\u001b[0m X,\n\u001b[1;32m 632\u001b[0m y,\n\u001b[1;32m 633\u001b[0m shrinkage\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mshrinkage,\n\u001b[1;32m 634\u001b[0m covariance_estimator\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_estimator,\n\u001b[1;32m 635\u001b[0m )\n",
271-
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:500\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis._solve_svd\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 497\u001b[0m n_samples, n_features \u001b[39m=\u001b[39m X\u001b[39m.\u001b[39mshape\n\u001b[1;32m 498\u001b[0m n_classes \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mclasses_\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]\n\u001b[0;32m--> 500\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmeans_ \u001b[39m=\u001b[39m _class_means(X, y)\n\u001b[1;32m 501\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstore_covariance:\n\u001b[1;32m 502\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_ \u001b[39m=\u001b[39m _class_cov(X, y, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_)\n",
272-
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:115\u001b[0m, in \u001b[0;36m_class_means\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Compute class means.\u001b[39;00m\n\u001b[1;32m 100\u001b[0m \n\u001b[1;32m 101\u001b[0m \u001b[39mParameters\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m Class means.\u001b[39;00m\n\u001b[1;32m 113\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 114\u001b[0m xp, is_array_api_compliant \u001b[39m=\u001b[39m get_namespace(X)\n\u001b[0;32m--> 115\u001b[0m classes, y \u001b[39m=\u001b[39m xp\u001b[39m.\u001b[39;49munique_inverse(y)\n\u001b[1;32m 116\u001b[0m means \u001b[39m=\u001b[39m xp\u001b[39m.\u001b[39mzeros((classes\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m], X\u001b[39m.\u001b[39mshape[\u001b[39m1\u001b[39m]), device\u001b[39m=\u001b[39mdevice(X), dtype\u001b[39m=\u001b[39mX\u001b[39m.\u001b[39mdtype)\n\u001b[1;32m 118\u001b[0m \u001b[39mif\u001b[39;00m is_array_api_compliant:\n",
273-
"\u001b[0;31mAttributeError\u001b[0m: module '__main__' has no attribute 'unique_inverse'"
274+
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:629\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_estimator \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 624\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 625\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mcovariance estimator \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 626\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mis not supported \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 627\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mwith svd solver. Try another solver\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 628\u001b[0m )\n\u001b[0;32m--> 629\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_solve_svd(X, y)\n\u001b[1;32m 630\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msolver \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mlsqr\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m 631\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_solve_lstsq(\n\u001b[1;32m 632\u001b[0m X,\n\u001b[1;32m 633\u001b[0m y,\n\u001b[1;32m 634\u001b[0m shrinkage\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mshrinkage,\n\u001b[1;32m 635\u001b[0m covariance_estimator\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_estimator,\n\u001b[1;32m 636\u001b[0m )\n",
275+
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:501\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis._solve_svd\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 498\u001b[0m n_samples, n_features \u001b[39m=\u001b[39m X\u001b[39m.\u001b[39mshape\n\u001b[1;32m 499\u001b[0m n_classes \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mclasses_\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]\n\u001b[0;32m--> 501\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmeans_ \u001b[39m=\u001b[39m _class_means(X, y)\n\u001b[1;32m 502\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstore_covariance:\n\u001b[1;32m 503\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_ \u001b[39m=\u001b[39m _class_cov(X, y, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_)\n",
276+
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:120\u001b[0m, in \u001b[0;36m_class_means\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[39mif\u001b[39;00m is_array_api_compliant:\n\u001b[1;32m 119\u001b[0m \u001b[39mprint\u001b[39m(classes\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m])\n\u001b[0;32m--> 120\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39;49m(classes\u001b[39m.\u001b[39;49mshape[\u001b[39m0\u001b[39;49m]):\n\u001b[1;32m 121\u001b[0m means[i, :] \u001b[39m=\u001b[39m xp\u001b[39m.\u001b[39mmean(X[y \u001b[39m==\u001b[39m i], axis\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)\n\u001b[1;32m 122\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 123\u001b[0m \u001b[39m# TODO: Explore the choice of using bincount + add.at as it seems sub optimal\u001b[39;00m\n\u001b[1;32m 124\u001b[0m \u001b[39m# from a performance-wise\u001b[39;00m\n",
277+
"\u001b[0;31mTypeError\u001b[0m: 'RuntimeExpr' object cannot be interpreted as an integer"
274278
]
275279
}
276280
],
@@ -575,6 +579,10 @@
575579
"\n",
576580
"\n",
577581
"@egraph.class_\n",
582+
"class Device(Expr): ...\n",
583+
"\n",
584+
"\n",
585+
"@egraph.class_\n",
578586
"class NDArray(Expr):\n",
579587
" def __init__(self, py_array: PyObject) -> None:\n",
580588
" ...\n",
@@ -597,6 +605,11 @@
597605
" ...\n",
598606
"\n",
599607
" @property\n",
608+
" def device(self) -> Device:\n",
609+
" ...\n",
610+
"\n",
611+
"\n",
612+
" @property\n",
600613
" def shape(self) -> TupleInt:\n",
601614
" ...\n",
602615
"\n",
@@ -733,6 +746,20 @@
733746
"converter(type(None), OptionalDType, lambda x: OptionalDType.none)\n",
734747
"converter(DType, OptionalDType, lambda x: OptionalDType.some(x))\n",
735748
"\n",
749+
"@egraph.class_\n",
750+
"class OptionalDevice(Expr):\n",
751+
" none: ClassVar[OptionalDevice]\n",
752+
"\n",
753+
" @classmethod\n",
754+
" def some(cls, value: Device) -> OptionalDevice:\n",
755+
" ...\n",
756+
"\n",
757+
"\n",
758+
"converter(type(None), OptionalDevice, lambda x: OptionalDevice.none)\n",
759+
"converter(Device, OptionalDevice, lambda x: OptionalDevice.some(x))\n",
760+
"\n",
761+
"\n",
762+
"\n",
736763
"\n",
737764
"@egraph.function\n",
738765
"def asarray(a: NDArray, dtype: OptionalDType = OptionalDType.none, copy: OptionalBool = OptionalBool.none) -> NDArray:\n",
@@ -849,6 +876,19 @@
849876
" rewrite(abs(NDArray.scalar_float(f))).to(NDArray.scalar_float(f)),\n",
850877
" ]\n",
851878
"\n",
879+
"@egraph.function\n",
880+
"def unique_inverse(x: NDArray) -> TupleNDArray:\n",
881+
" ...\n",
882+
"\n",
883+
"@egraph.register\n",
884+
"def _unique_inverse(x: NDArray):\n",
885+
" return [\n",
886+
" rewrite(unique_inverse(x).length()).to(Int(2)),\n",
887+
" ]\n",
888+
"\n",
889+
"@egraph.function\n",
890+
"def zeros(shape: TupleInt, dtype: OptionalDType = OptionalDType.none, device: OptionalDevice = OptionalDevice.none) -> NDArray:\n",
891+
" ...\n",
852892
"\n",
853893
"linalg = sys.modules[__name__]\n",
854894
"\n",
@@ -922,6 +962,36 @@
922962
"# y_arr = NDArray(y_obj)"
923963
]
924964
},
965+
{
966+
"cell_type": "code",
967+
"execution_count": 4,
968+
"metadata": {},
969+
"outputs": [],
970+
"source": [
971+
"x = unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(0)].shape[Int(0)]\n"
972+
]
973+
},
974+
{
975+
"cell_type": "code",
976+
"execution_count": 9,
977+
"metadata": {},
978+
"outputs": [
979+
{
980+
"ename": "TypeError",
981+
"evalue": "'RuntimeExpr' object cannot be interpreted as an integer",
982+
"output_type": "error",
983+
"traceback": [
984+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
985+
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
986+
"Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mrange\u001b[39;49m(Int(\u001b[39m10\u001b[39;49m))\n",
987+
"\u001b[0;31mTypeError\u001b[0m: 'RuntimeExpr' object cannot be interpreted as an integer"
988+
]
989+
}
990+
],
991+
"source": [
992+
"range(Int(10))"
993+
]
994+
},
925995
{
926996
"cell_type": "code",
927997
"execution_count": null,

0 commit comments

Comments
 (0)