Skip to content

Commit 866647d

Browse files
committed
Move main namespace linear algebra helpers to _aliases.py
1 parent 48d1ae1 commit 866647d

File tree

2 files changed

+39
-35
lines changed

2 files changed

+39
-35
lines changed

array_api_compat/common/_aliases.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import TYPE_CHECKING
88
if TYPE_CHECKING:
9-
from typing import Optional, Tuple, Union, List
9+
from typing import Optional, Sequence, Tuple, Union, List
1010
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
1111

1212
from typing import NamedTuple
@@ -408,7 +408,43 @@ def trunc(x: ndarray, /, xp, **kwargs) -> ndarray:
408408
return x
409409
return xp.trunc(x, **kwargs)
410410

411+
# linear algebra functions
412+
413+
def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
414+
return xp.matmul(x1, x2, **kwargs)
415+
416+
# Unlike transpose, matrix_transpose only transposes the last two axes.
417+
def matrix_transpose(x: ndarray, /, xp) -> ndarray:
418+
if x.ndim < 2:
419+
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
420+
return xp.swapaxes(x, -1, -2)
421+
422+
def tensordot(x1: ndarray,
423+
x2: ndarray,
424+
/,
425+
xp,
426+
*,
427+
axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
428+
**kwargs,
429+
) -> ndarray:
430+
return xp.tensordot(x1, x2, axes=axes, **kwargs)
431+
432+
def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
433+
ndim = max(x1.ndim, x2.ndim)
434+
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
435+
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
436+
if x1_shape[axis] != x2_shape[axis]:
437+
raise ValueError("x1 and x2 must have the same size along the given axis")
438+
439+
x1_, x2_ = xp.broadcast_arrays(x1, x2)
440+
x1_ = xp.moveaxis(x1_, axis, -1)
441+
x2_ = xp.moveaxis(x2_, axis, -1)
442+
443+
res = x1_[..., None, :] @ x2_[..., None]
444+
return res[..., 0, 0]
445+
411446
__all__ = ['UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
412447
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
413448
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
414-
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc']
449+
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul',
450+
'matrix_transpose', 'tensordot', 'vecdot']

array_api_compat/common/_linalg.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,16 @@
77

88
from numpy.core.numeric import normalize_axis_tuple
99

10+
from .._aliases import matmul, matrix_transpose, tensordot, vecdot
1011
from .._internal import get_xp
1112

1213
# These are in the main NumPy namespace but not in numpy.linalg
1314
def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray:
1415
return xp.cross(x1, x2, axis=axis, **kwargs)
1516

16-
def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
17-
return xp.matmul(x1, x2, **kwargs)
18-
1917
def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray:
2018
return xp.outer(x1, x2, **kwargs)
2119

22-
def tensordot(x1: ndarray,
23-
x2: ndarray,
24-
/,
25-
xp,
26-
*,
27-
axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
28-
**kwargs,
29-
) -> ndarray:
30-
return xp.tensordot(x1, x2, axes=axes, **kwargs)
31-
3220
class EighResult(NamedTuple):
3321
eigenvalues: ndarray
3422
eigenvectors: ndarray
@@ -103,31 +91,11 @@ def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **k
10391
def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray:
10492
return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
10593

106-
# Unlike transpose, matrix_transpose only transposes the last two axes.
107-
def matrix_transpose(x: ndarray, /, xp) -> ndarray:
108-
if x.ndim < 2:
109-
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
110-
return xp.swapaxes(x, -1, -2)
111-
11294
# svdvals is not in NumPy (but it is in SciPy). It is equivalent to
11395
# xp.linalg.svd(compute_uv=False).
11496
def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]:
11597
return xp.linalg.svd(x, compute_uv=False)
11698

117-
def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
118-
ndim = max(x1.ndim, x2.ndim)
119-
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
120-
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
121-
if x1_shape[axis] != x2_shape[axis]:
122-
raise ValueError("x1 and x2 must have the same size along the given axis")
123-
124-
x1_, x2_ = xp.broadcast_arrays(x1, x2)
125-
x1_ = xp.moveaxis(x1_, axis, -1)
126-
x2_ = xp.moveaxis(x2_, axis, -1)
127-
128-
res = x1_[..., None, :] @ x2_[..., None]
129-
return res[..., 0, 0]
130-
13199
def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray:
132100
# xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
133101
# when axis=None and the input is 2-D, so to force a vector norm, we make

0 commit comments

Comments
 (0)