Skip to content

Commit 3c224a5

Browse files
committed
Add matrix_transpose, svdvals, and vecdot
1 parent 7661247 commit 3c224a5

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

numpy_array_api_compat/linalg.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,33 @@
1111
def matrix_norm(x: ndarray, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray:
1212
return np.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
1313

14+
# this function is new in the array API spec. Unlike transpose, it only
15+
# transposes the last two axes.
16+
def matrix_transpose(x: ndarray, /) -> ndarray:
17+
if x.ndim < 2:
18+
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
19+
return np.swapaxes(x, -1, -2)
20+
21+
# svdvals is not in NumPy (but it is in SciPy). It is equivalent to
22+
# np.linalg.svd(compute_uv=False).
23+
def svdvals(x: ndarray, /) -> Union[ndarray, Tuple[ndarray, ...]]:
24+
return np.linalg.svd(x, compute_uv=False)
25+
26+
# vecdot is not in NumPy
27+
def vecdot(x1: ndarray, x2: ndarray, /, *, axis: int = -1) -> ndarray:
28+
ndim = max(x1.ndim, x2.ndim)
29+
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
30+
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
31+
if x1_shape[axis] != x2_shape[axis]:
32+
raise ValueError("x1 and x2 must have the same size along the given axis")
33+
34+
x1_, x2_ = np.broadcast_arrays(x1, x2)
35+
x1_ = np.moveaxis(x1_, axis, -1)
36+
x2_ = np.moveaxis(x2_, axis, -1)
37+
38+
res = x1_[..., None, :] @ x2_[..., None]
39+
return res[..., 0, 0]
40+
1441
def vector_norm(x: ndarray, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray:
1542
# np.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
1643
# when axis=None and the input is 2-D, so to force a vector norm, we make
@@ -52,5 +79,5 @@ def vector_norm(x: ndarray, /, *, axis: Optional[Union[int, Tuple[int, ...]]] =
5279
from numpy import cross, diagonal, matmul, outer, tensordot, trace
5380

5481
__all__ = linalg_all.copy()
55-
__all__ += ['cross', 'diagonal', 'matmul', 'matrix_norm', 'outer',
56-
'tensordot', 'trace', 'vector_norm']
82+
__all__ += ['cross', 'diagonal', 'matmul', 'matrix_norm', 'matrix_transpose',
83+
'outer', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']

0 commit comments

Comments
 (0)