@@ -344,6 +344,8 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
344
344
# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
345
345
# np.linalg.svd(compute_uv=False).
346
346
def svdvals (x : Array , / ) -> Union [Array , Tuple [Array , ...]]:
347
+ if x .dtype not in _floating_dtypes :
348
+ raise TypeError ('Only floating-point dtypes are allowed in svdvals' )
347
349
return Array ._new (np .linalg .svd (x ._array , compute_uv = False ))
348
350
349
351
# Note: tensordot is the numpy top-level namespace but not in np.linalg
@@ -364,12 +366,16 @@ def trace(x: Array, /, *, offset: int = 0) -> Array:
364
366
365
367
See its docstring for more information.
366
368
"""
369
+ if x .dtype not in _numeric_dtypes :
370
+ raise TypeError ('Only numeric dtypes are allowed in trace' )
367
371
# Note: trace always operates on the last two axes, whereas np.trace
368
372
# operates on the first two axes by default
369
373
return Array ._new (np .asarray (np .trace (x ._array , offset = offset , axis1 = - 2 , axis2 = - 1 )))
370
374
371
375
# Note: vecdot is not in NumPy
372
376
def vecdot (x1 : Array , x2 : Array , / , * , axis : int = - 1 ) -> Array :
377
+ if x1 .dtype not in _numeric_dtypes or x2 .dtype not in _numeric_dtypes :
378
+ raise TypeError ('Only numeric dtypes are allowed in vecdot' )
373
379
return tensordot (x1 , x2 , axes = ((axis ,), (axis ,)))
374
380
375
381
0 commit comments