Skip to content

Commit e9b3500

Browse files
add svd
1 parent 073a561 commit e9b3500

File tree

1 file changed

+43
-18
lines changed

1 file changed

+43
-18
lines changed

docs/tutorials/array-api.ipynb

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
},
8585
{
8686
"cell_type": "code",
87-
"execution_count": 9,
87+
"execution_count": 13,
8888
"metadata": {},
8989
"outputs": [
9090
{
@@ -235,33 +235,42 @@
235235
" )\n",
236236
" > NDArray.scalar_float(Float(1e-05))\n",
237237
").bool()\n",
238-
" -> FALSE\n"
238+
" -> FALSE\n",
239+
"asarray(NDArray.var(\"X\")).shape[Int(1)] < (\n",
240+
" unique_values(concat((TupleNDArray(unique_values(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))))) + TupleNDArray.EMPTY))).shape[\n",
241+
" Int(0)\n",
242+
" ]\n",
243+
" - Int(1)\n",
244+
")\n",
245+
" -> NDArray.var(\"X\").shape[Int(1)] < (unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)] - Int(1))\n",
246+
" -> FALSE\n",
247+
"(\n",
248+
" unique_values(concat((TupleNDArray(unique_values(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))))) + TupleNDArray.EMPTY))).shape[\n",
249+
" Int(0)\n",
250+
" ]\n",
251+
" - Int(1)\n",
252+
") < Int(2)\n",
253+
" -> (unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)] - Int(1)) < Int(2)\n",
254+
" -> FALSE\n",
255+
"asarray(NDArray.var(\"X\")).shape.length()\n",
256+
" -> NDArray.var(\"X\").ndim\n",
257+
" -> Int(2)\n"
239258
]
240259
},
241260
{
242261
"ename": "AttributeError",
243-
"evalue": "Class Int does not have method __sub__",
262+
"evalue": "module '__main__' has no attribute 'unique_inverse'",
244263
"output_type": "error",
245264
"traceback": [
246265
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
247-
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
248-
"File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:388\u001b[0m, in \u001b[0;36m_special_method\u001b[0;34m(self, __name, *args)\u001b[0m\n\u001b[1;32m 387\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 388\u001b[0m method \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__egg_decls__\u001b[39m.\u001b[39;49mget_class_decl(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__egg_typed_expr__\u001b[39m.\u001b[39;49mtp\u001b[39m.\u001b[39;49mname)\u001b[39m.\u001b[39;49mpreserved_methods[__name]\n\u001b[1;32m 389\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m:\n",
249-
"\u001b[0;31mKeyError\u001b[0m: '__sub__'",
250-
"\nDuring handling of the above exception, another exception occurred:\n",
251-
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
252-
"File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:318\u001b[0m, in \u001b[0;36mRuntimeMethod.__post_init__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 317\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 318\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_fn_decl__ \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__egg_decls__\u001b[39m.\u001b[39;49mget_function_decl(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__egg_callable_ref__)\n\u001b[1;32m 319\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m:\n",
253-
"File \u001b[0;32m~/p/egg-smol-python/python/egglog/declarations.py:192\u001b[0m, in \u001b[0;36mModuleDeclarations.get_function_decl\u001b[0;34m(self, ref)\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[39mpass\u001b[39;00m\n\u001b[0;32m--> 192\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mKeyError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mFunction \u001b[39m\u001b[39m{\u001b[39;00mref\u001b[39m}\u001b[39;00m\u001b[39m not found\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 193\u001b[0m \u001b[39melse\u001b[39;00m:\n",
254-
"\u001b[0;31mKeyError\u001b[0m: \"Function MethodRef(class_name='Int', method_name='__sub__') not found\"",
255-
"\nDuring handling of the above exception, another exception occurred:\n",
256266
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
257-
"Cell \u001b[0;32mIn[9], line 624\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[39m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 611\u001b[0m egraph\u001b[39m.\u001b[39mregister(\n\u001b[1;32m 612\u001b[0m rewrite(X_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(X\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[1;32m 613\u001b[0m rewrite(y_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(y\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 620\u001b[0m rewrite(unique_values(y_arr)\u001b[39m.\u001b[39mshape)\u001b[39m.\u001b[39mto(TupleInt(Int(\u001b[39m3\u001b[39m))),\n\u001b[1;32m 621\u001b[0m )\n\u001b[0;32m--> 624\u001b[0m res \u001b[39m=\u001b[39m fit(X_arr, y_arr)\n\u001b[1;32m 626\u001b[0m \u001b[39m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 627\u001b[0m \n\u001b[1;32m 628\u001b[0m \u001b[39m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 629\u001b[0m \u001b[39m# y_arr = NDArray(y_obj)\u001b[39;00m\n",
267+
"Cell \u001b[0;32mIn[13], line 640\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[39m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 627\u001b[0m egraph\u001b[39m.\u001b[39mregister(\n\u001b[1;32m 628\u001b[0m rewrite(X_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(X\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[1;32m 629\u001b[0m rewrite(y_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(y\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 636\u001b[0m rewrite(unique_values(y_arr)\u001b[39m.\u001b[39mshape)\u001b[39m.\u001b[39mto(TupleInt(Int(\u001b[39m3\u001b[39m))),\n\u001b[1;32m 637\u001b[0m )\n\u001b[0;32m--> 640\u001b[0m res \u001b[39m=\u001b[39m fit(X_arr, y_arr)\n\u001b[1;32m 642\u001b[0m \u001b[39m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 643\u001b[0m \n\u001b[1;32m 644\u001b[0m \u001b[39m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 645\u001b[0m \u001b[39m# y_arr = NDArray(y_obj)\u001b[39;00m\n",
258268
"Cell \u001b[0;32mIn[2], line 15\u001b[0m, in \u001b[0;36mfit\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[39mwith\u001b[39;00m config_context(array_api_dispatch\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m):\n\u001b[1;32m 14\u001b[0m lda \u001b[39m=\u001b[39m LinearDiscriminantAnalysis(n_components\u001b[39m=\u001b[39m\u001b[39m2\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m X_r2 \u001b[39m=\u001b[39m lda\u001b[39m.\u001b[39;49mfit(X, y)\u001b[39m.\u001b[39mtransform(X)\n\u001b[1;32m 16\u001b[0m \u001b[39mreturn\u001b[39;00m X_r2\n\u001b[1;32m 18\u001b[0m target_names \u001b[39m=\u001b[39m iris\u001b[39m.\u001b[39mtarget_names\n",
259269
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/base.py:1151\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1144\u001b[0m estimator\u001b[39m.\u001b[39m_validate_params()\n\u001b[1;32m 1146\u001b[0m \u001b[39mwith\u001b[39;00m config_context(\n\u001b[1;32m 1147\u001b[0m skip_parameter_validation\u001b[39m=\u001b[39m(\n\u001b[1;32m 1148\u001b[0m prefer_skip_nested_validation \u001b[39mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 1149\u001b[0m )\n\u001b[1;32m 1150\u001b[0m ):\n\u001b[0;32m-> 1151\u001b[0m \u001b[39mreturn\u001b[39;00m fit_method(estimator, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
260-
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:608\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 604\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_ \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_ \u001b[39m/\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_\u001b[39m.\u001b[39msum()\n\u001b[1;32m 606\u001b[0m \u001b[39m# Maximum number of components no matter what n_components is\u001b[39;00m\n\u001b[1;32m 607\u001b[0m \u001b[39m# specified:\u001b[39;00m\n\u001b[0;32m--> 608\u001b[0m max_components \u001b[39m=\u001b[39m \u001b[39mmin\u001b[39m(n_classes \u001b[39m-\u001b[39;49m \u001b[39m1\u001b[39;49m, X\u001b[39m.\u001b[39mshape[\u001b[39m1\u001b[39m])\n\u001b[1;32m 610\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_components \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 611\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_max_components \u001b[39m=\u001b[39m max_components\n",
261-
"File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:390\u001b[0m, in \u001b[0;36m_special_method\u001b[0;34m(self, __name, *args)\u001b[0m\n\u001b[1;32m 388\u001b[0m method \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_decls__\u001b[39m.\u001b[39mget_class_decl(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_typed_expr__\u001b[39m.\u001b[39mtp\u001b[39m.\u001b[39mname)\u001b[39m.\u001b[39mpreserved_methods[__name]\n\u001b[1;32m 389\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m:\n\u001b[0;32m--> 390\u001b[0m \u001b[39mreturn\u001b[39;00m RuntimeMethod(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__egg_decls__, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__egg_typed_expr__, __name)(\u001b[39m*\u001b[39margs)\n\u001b[1;32m 391\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 392\u001b[0m \u001b[39mreturn\u001b[39;00m method(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs)\n",
262-
"File \u001b[0;32m<string>:6\u001b[0m, in \u001b[0;36m__init__\u001b[0;34m(self, __egg_decls__, __egg_typed_expr__, __egg_method_name__)\u001b[0m\n",
263-
"File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:320\u001b[0m, in \u001b[0;36mRuntimeMethod.__post_init__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 318\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_fn_decl__ \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_decls__\u001b[39m.\u001b[39mget_function_decl(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_callable_ref__)\n\u001b[1;32m 319\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m:\n\u001b[0;32m--> 320\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mClass \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mclass_name\u001b[39m}\u001b[39;00m\u001b[39m does not have method \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_method_name__\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
264-
"\u001b[0;31mAttributeError\u001b[0m: Class Int does not have method __sub__"
270+
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:628\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 622\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_estimator \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 623\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 624\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mcovariance estimator \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 625\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mis not supported \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 626\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mwith svd solver. Try another solver\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 627\u001b[0m )\n\u001b[0;32m--> 628\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_solve_svd(X, y)\n\u001b[1;32m 629\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msolver \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mlsqr\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m 630\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_solve_lstsq(\n\u001b[1;32m 631\u001b[0m X,\n\u001b[1;32m 632\u001b[0m y,\n\u001b[1;32m 633\u001b[0m shrinkage\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mshrinkage,\n\u001b[1;32m 634\u001b[0m covariance_estimator\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_estimator,\n\u001b[1;32m 635\u001b[0m )\n",
271+
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:500\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis._solve_svd\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 497\u001b[0m n_samples, n_features \u001b[39m=\u001b[39m X\u001b[39m.\u001b[39mshape\n\u001b[1;32m 498\u001b[0m n_classes \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mclasses_\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]\n\u001b[0;32m--> 500\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmeans_ \u001b[39m=\u001b[39m _class_means(X, y)\n\u001b[1;32m 501\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstore_covariance:\n\u001b[1;32m 502\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_ \u001b[39m=\u001b[39m _class_cov(X, y, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_)\n",
272+
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:115\u001b[0m, in \u001b[0;36m_class_means\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Compute class means.\u001b[39;00m\n\u001b[1;32m 100\u001b[0m \n\u001b[1;32m 101\u001b[0m \u001b[39mParameters\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m Class means.\u001b[39;00m\n\u001b[1;32m 113\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 114\u001b[0m xp, is_array_api_compliant \u001b[39m=\u001b[39m get_namespace(X)\n\u001b[0;32m--> 115\u001b[0m classes, y \u001b[39m=\u001b[39m xp\u001b[39m.\u001b[39;49munique_inverse(y)\n\u001b[1;32m 116\u001b[0m means \u001b[39m=\u001b[39m xp\u001b[39m.\u001b[39mzeros((classes\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m], X\u001b[39m.\u001b[39mshape[\u001b[39m1\u001b[39m]), device\u001b[39m=\u001b[39mdevice(X), dtype\u001b[39m=\u001b[39mX\u001b[39m.\u001b[39mdtype)\n\u001b[1;32m 118\u001b[0m \u001b[39mif\u001b[39;00m is_array_api_compliant:\n",
273+
"\u001b[0;31mAttributeError\u001b[0m: module '__main__' has no attribute 'unique_inverse'"
265274
]
266275
}
267276
],
@@ -273,6 +282,7 @@
273282
"from egglog.egraph import Unit\n",
274283
"import numpy as np\n",
275284
"import numbers\n",
285+
"from types import SimpleNamespace\n",
276286
"\n",
277287
"from egglog import *\n",
278288
"\n",
@@ -454,6 +464,7 @@
454464
"\n",
455465
" def __add__(self, other: Int) -> Int:\n",
456466
" ...\n",
467+
" def __sub__(self, other: Int) -> Int: ...\n",
457468
"\n",
458469
" @egraph.method(preserve=True)\n",
459470
" def __int__(self) -> int:\n",
@@ -491,6 +502,7 @@
491502
" yield rule(eq(o).to(Int(j))).then(set_(o.to_py()).to(PyObject.from_int(j)))\n",
492503
"\n",
493504
" yield rewrite(Int(i) + Int(j)).to(Int(i + j))\n",
505+
" yield rewrite(Int(i) - Int(j)).to(Int(i - j))\n",
494506
"\n",
495507
"\n",
496508
"converter(int, Int, lambda x: Int(x))\n",
@@ -838,6 +850,19 @@
838850
" ]\n",
839851
"\n",
840852
"\n",
853+
"linalg = sys.modules[__name__]\n",
854+
"\n",
855+
"@egraph.function\n",
856+
"def svd(x: NDArray) -> TupleNDArray:\n",
857+
" ...\n",
858+
"\n",
859+
"\n",
860+
"@egraph.register\n",
861+
"def _linalg(x: NDArray):\n",
862+
" return [\n",
863+
" rewrite(svd(x).length()).to(Int(3)),\n",
864+
" ]\n",
865+
"\n",
841866
"##\n",
842867
"# Interval analysis\n",
843868
"#\n",

0 commit comments

Comments
 (0)