|
8 | 8 | from types import ModuleType
|
9 | 9 | from typing import TYPE_CHECKING, Any, cast, overload
|
10 | 10 |
|
11 |
| -from ._utils._compat import ( |
12 |
| - array_namespace, |
13 |
| - is_dask_namespace, |
14 |
| - is_jax_namespace |
15 |
| -) |
| 11 | +from ._utils._compat import array_namespace, is_dask_namespace, is_jax_namespace |
16 | 12 | from ._utils._typing import Array, DType
|
17 | 13 |
|
18 | 14 | if TYPE_CHECKING:
|
| 15 | + # TODO move outside TYPE_CHECKING |
| 16 | + # depends on scikit-learn abandoning Python 3.9 |
19 | 17 | # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
|
20 | 18 | from typing import ParamSpec, TypeAlias
|
21 | 19 |
|
@@ -72,8 +70,8 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
|
72 | 70 | positional arguments and return either a single NumPy array or generic, or a
|
73 | 71 | tuple or list thereof.
|
74 | 72 |
|
75 |
| - It must be a pure function, i.e. without side effects such as disk output, |
76 |
| - as depending on the backend it may be executed more than once. |
| 73 | + It must be a pure function, i.e. without side effects, as depending on the |
| 74 | + backend it may be executed more than once. |
77 | 75 | *args : Array
|
78 | 76 | One or more Array API compliant arrays. You need to be able to apply
|
79 | 77 | :func:`numpy.asarray` to them to convert them to numpy; read notes below about
|
@@ -225,18 +223,6 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
|
225 | 223 | wrapped = _npfunc_wrapper(func, multi_output, xp)
|
226 | 224 | out = wrapped(*args, **kwargs)
|
227 | 225 |
|
228 |
| - # Output validation |
229 |
| - if len(out) != len(shapes): |
230 |
| - msg = f"func was declared to return {len(shapes)} outputs, got {len(out)}" |
231 |
| - raise ValueError(msg) |
232 |
| - for out_i, shape_i, dtype_i in zip(out, shapes, dtypes, strict=True): |
233 |
| - if out_i.shape != shape_i: |
234 |
| - msg = f"expected shape {shape_i}, got {out_i.shape}" |
235 |
| - raise ValueError(msg) |
236 |
| - if not xp.isdtype(out_i.dtype, dtype_i): |
237 |
| - msg = f"expected dtype {dtype_i}, got {out_i.dtype}" |
238 |
| - raise ValueError(msg) |
239 |
| - |
240 | 226 | return out if multi_output else out[0]
|
241 | 227 |
|
242 | 228 |
|
|
0 commit comments