Skip to content

Commit 814dfb7

Browse files
committed
nits
1 parent 58292d5 commit 814dfb7

File tree

2 files changed

+6
-20
lines changed

2 files changed

+6
-20
lines changed

src/array_api_extra/_lib/_apply.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@
88
from types import ModuleType
99
from typing import TYPE_CHECKING, Any, cast, overload
1010

11-
from ._utils._compat import (
12-
array_namespace,
13-
is_dask_namespace,
14-
is_jax_namespace
15-
)
11+
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_namespace
1612
from ._utils._typing import Array, DType
1713

1814
if TYPE_CHECKING:
15+
# TODO move outside TYPE_CHECKING
16+
# depends on scikit-learn abandoning Python 3.9
1917
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
2018
from typing import ParamSpec, TypeAlias
2119

@@ -72,8 +70,8 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
7270
positional arguments and return either a single NumPy array or generic, or a
7371
tuple or list thereof.
7472
75-
It must be a pure function, i.e. without side effects such as disk output,
76-
as depending on the backend it may be executed more than once.
73+
It must be a pure function, i.e. without side effects, as depending on the
74+
backend it may be executed more than once.
7775
*args : Array
7876
One or more Array API compliant arrays. You need to be able to apply
7977
:func:`numpy.asarray` to them to convert them to numpy; read notes below about
@@ -225,18 +223,6 @@ def apply_numpy_func( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04
225223
wrapped = _npfunc_wrapper(func, multi_output, xp)
226224
out = wrapped(*args, **kwargs)
227225

228-
# Output validation
229-
if len(out) != len(shapes):
230-
msg = f"func was declared to return {len(shapes)} outputs, got {len(out)}"
231-
raise ValueError(msg)
232-
for out_i, shape_i, dtype_i in zip(out, shapes, dtypes, strict=True):
233-
if out_i.shape != shape_i:
234-
msg = f"expected shape {shape_i}, got {out_i.shape}"
235-
raise ValueError(msg)
236-
if not xp.isdtype(out_i.dtype, dtype_i):
237-
msg = f"expected dtype {dtype_i}, got {out_i.dtype}"
238-
raise ValueError(msg)
239-
240226
return out if multi_output else out[0]
241227

242228

src/array_api_extra/_lib/_utils/_compat.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
2323
def is_dask_namespace(xp: ModuleType, /) -> bool: ...
2424
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
2525
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
26+
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
2627
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
2728
def is_jax_array(x: object, /) -> bool: ...
28-
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
2929
def is_writeable_array(x: object, /) -> bool: ...
3030
def size(x: Array, /) -> int | None: ...

0 commit comments

Comments
 (0)