|
11 | 11 | def matrix_norm(x: ndarray, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray:
|
12 | 12 | return np.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
|
13 | 13 |
|
| 14 | +# this function is new in the array API spec. Unlike transpose, it only |
| 15 | +# transposes the last two axes. |
| 16 | +def matrix_transpose(x: ndarray, /) -> ndarray: |
| 17 | + if x.ndim < 2: |
| 18 | + raise ValueError("x must be at least 2-dimensional for matrix_transpose") |
| 19 | + return np.swapaxes(x, -1, -2) |
| 20 | + |
| 21 | +# svdvals is not in NumPy (but it is in SciPy). It is equivalent to |
| 22 | +# np.linalg.svd(compute_uv=False). |
| 23 | +def svdvals(x: ndarray, /) -> Union[ndarray, Tuple[ndarray, ...]]: |
| 24 | + return np.linalg.svd(x, compute_uv=False) |
| 25 | + |
| 26 | +# vecdot is not in NumPy |
| 27 | +def vecdot(x1: ndarray, x2: ndarray, /, *, axis: int = -1) -> ndarray: |
| 28 | + ndim = max(x1.ndim, x2.ndim) |
| 29 | + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) |
| 30 | + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) |
| 31 | + if x1_shape[axis] != x2_shape[axis]: |
| 32 | + raise ValueError("x1 and x2 must have the same size along the given axis") |
| 33 | + |
| 34 | + x1_, x2_ = np.broadcast_arrays(x1, x2) |
| 35 | + x1_ = np.moveaxis(x1_, axis, -1) |
| 36 | + x2_ = np.moveaxis(x2_, axis, -1) |
| 37 | + |
| 38 | + res = x1_[..., None, :] @ x2_[..., None] |
| 39 | + return res[..., 0, 0] |
| 40 | + |
14 | 41 | def vector_norm(x: ndarray, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray:
|
15 | 42 | # np.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
|
16 | 43 | # when axis=None and the input is 2-D, so to force a vector norm, we make
|
@@ -52,5 +79,5 @@ def vector_norm(x: ndarray, /, *, axis: Optional[Union[int, Tuple[int, ...]]] =
|
52 | 79 | from numpy import cross, diagonal, matmul, outer, tensordot, trace
|
53 | 80 |
|
54 | 81 | __all__ = linalg_all.copy()
|
55 |
| -__all__ += ['cross', 'diagonal', 'matmul', 'matrix_norm', 'outer', |
56 |
| - 'tensordot', 'trace', 'vector_norm'] |
| 82 | +__all__ += ['cross', 'diagonal', 'matmul', 'matrix_norm', 'matrix_transpose', |
| 83 | + 'outer', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm'] |
0 commit comments