|
7 | 7 |
|
8 | 8 | from numpy.core.numeric import normalize_axis_tuple
|
9 | 9 |
|
| 10 | +from .._aliases import matmul, matrix_transpose, tensordot, vecdot |
10 | 11 | from .._internal import get_xp
|
11 | 12 |
|
12 | 13 | # These are in the main NumPy namespace but not in numpy.linalg
|
13 | 14 | def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray:
|
14 | 15 | return xp.cross(x1, x2, axis=axis, **kwargs)
|
15 | 16 |
|
16 |
| -def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: |
17 |
| - return xp.matmul(x1, x2, **kwargs) |
18 |
| - |
19 | 17 | def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
|
20 | 18 | return xp.outer(x1, x2, **kwargs)
|
21 | 19 |
|
22 |
| -def tensordot(x1: ndarray, |
23 |
| - x2: ndarray, |
24 |
| - /, |
25 |
| - xp, |
26 |
| - *, |
27 |
| - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, |
28 |
| - **kwargs, |
29 |
| -) -> ndarray: |
30 |
| - return xp.tensordot(x1, x2, axes=axes, **kwargs) |
31 |
| - |
32 | 20 | class EighResult(NamedTuple):
|
33 | 21 | eigenvalues: ndarray
|
34 | 22 | eigenvectors: ndarray
|
@@ -103,31 +91,11 @@ def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **k
|
103 | 91 | def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray:
|
104 | 92 | return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
|
105 | 93 |
|
106 |
| -# Unlike transpose, matrix_transpose only transposes the last two axes. |
107 |
| -def matrix_transpose(x: ndarray, /, xp) -> ndarray: |
108 |
| - if x.ndim < 2: |
109 |
| - raise ValueError("x must be at least 2-dimensional for matrix_transpose") |
110 |
| - return xp.swapaxes(x, -1, -2) |
111 |
| - |
112 | 94 | # svdvals is not in NumPy (but it is in SciPy). It is equivalent to
|
113 | 95 | # xp.linalg.svd(compute_uv=False).
|
114 | 96 | def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]:
|
115 | 97 | return xp.linalg.svd(x, compute_uv=False)
|
116 | 98 |
|
117 |
| -def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: |
118 |
| - ndim = max(x1.ndim, x2.ndim) |
119 |
| - x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) |
120 |
| - x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) |
121 |
| - if x1_shape[axis] != x2_shape[axis]: |
122 |
| - raise ValueError("x1 and x2 must have the same size along the given axis") |
123 |
| - |
124 |
| - x1_, x2_ = xp.broadcast_arrays(x1, x2) |
125 |
| - x1_ = xp.moveaxis(x1_, axis, -1) |
126 |
| - x2_ = xp.moveaxis(x2_, axis, -1) |
127 |
| - |
128 |
| - res = x1_[..., None, :] @ x2_[..., None] |
129 |
| - return res[..., 0, 0] |
130 |
| - |
131 | 99 | def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray:
|
132 | 100 | # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
|
133 | 101 | # when axis=None and the input is 2-D, so to force a vector norm, we make
|
|
0 commit comments