Skip to content

Commit d8e1a74

Browse files
committed
Properly restrict the input dtypes for the array_api trace, svdvals, and vecdot
Original NumPy Commit: f375d71ca101db9541b1e70476999b574634556d
1 parent d35b040 commit d8e1a74

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

array_api_strict/linalg.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
344344
# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
345345
# np.linalg.svd(compute_uv=False).
346346
def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]:
347+
if x.dtype not in _floating_dtypes:
348+
raise TypeError('Only floating-point dtypes are allowed in svdvals')
347349
return Array._new(np.linalg.svd(x._array, compute_uv=False))
348350

349351
# Note: tensordot is the numpy top-level namespace but not in np.linalg
@@ -364,12 +366,16 @@ def trace(x: Array, /, *, offset: int = 0) -> Array:
364366
365367
See its docstring for more information.
366368
"""
369+
if x.dtype not in _numeric_dtypes:
370+
raise TypeError('Only numeric dtypes are allowed in trace')
367371
# Note: trace always operates on the last two axes, whereas np.trace
368372
# operates on the first two axes by default
369373
return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1)))
370374

371375
# Note: vecdot is not in NumPy
372376
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
377+
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
378+
raise TypeError('Only numeric dtypes are allowed in vecdot')
373379
return tensordot(x1, x2, axes=((axis,), (axis,)))
374380

375381

0 commit comments

Comments
 (0)