Skip to content

Commit c441f33

Browse files
committed
Fix main namespace linalg functions in numpy and cupy
1 parent 85b71de commit c441f33

File tree

6 files changed

+13
-10
lines changed

6 files changed

+13
-10
lines changed

array_api_compat/common/_aliases.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,8 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
444444
return res[..., 0, 0]
445445

446446
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
447-
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
447+
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
448+
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
448449
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
449450
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
450451
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul',

array_api_compat/common/_linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from numpy.core.numeric import normalize_axis_tuple
99

10-
from .._aliases import matmul, matrix_transpose, tensordot, vecdot
10+
from ._aliases import matmul, matrix_transpose, tensordot, vecdot
1111
from .._internal import get_xp
1212

1313
# These are in the main NumPy namespace but not in numpy.linalg

array_api_compat/cupy/_aliases.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@
5757
ceil = get_xp(cp)(_aliases.ceil)
5858
floor = get_xp(cp)(_aliases.floor)
5959
trunc = get_xp(cp)(_aliases.trunc)
60+
matmul = get_xp(cp)(_aliases.matmul)
61+
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
62+
tensordot = get_xp(cp)(_aliases.tensordot)
63+
vecdot = get_xp(cp)(_aliases.vecdot)
6064

6165
__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
6266
'acosh', 'asin', 'asinh', 'atan', 'atan2',

array_api_compat/cupy/linalg.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@
1010

1111
from ..common import _linalg
1212
from .._internal import get_xp
13+
from ._aliases import (matmul, matrix_transpose, tensordot, vecdot)
1314

1415
import cupy as cp
1516

1617
cross = get_xp(cp)(_linalg.cross)
17-
matmul = get_xp(cp)(_linalg.matmul)
1818
outer = get_xp(cp)(_linalg.outer)
19-
tensordot = get_xp(cp)(_linalg.tensordot)
2019
EighResult = _linalg.EighResult
2120
QRResult = _linalg.QRResult
2221
SlogdetResult = _linalg.SlogdetResult
@@ -29,9 +28,7 @@
2928
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
3029
pinv = get_xp(cp)(_linalg.pinv)
3130
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
32-
matrix_transpose = get_xp(cp)(_linalg.matrix_transpose)
3331
svdvals = get_xp(cp)(_linalg.svdvals)
34-
vecdot = get_xp(cp)(_linalg.vecdot)
3532
vector_norm = get_xp(cp)(_linalg.vector_norm)
3633
diagonal = get_xp(cp)(_linalg.diagonal)
3734
trace = get_xp(cp)(_linalg.trace)

array_api_compat/numpy/_aliases.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@
5757
ceil = get_xp(np)(_aliases.ceil)
5858
floor = get_xp(np)(_aliases.floor)
5959
trunc = get_xp(np)(_aliases.trunc)
60+
matmul = get_xp(np)(_aliases.matmul)
61+
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
62+
tensordot = get_xp(np)(_aliases.tensordot)
63+
vecdot = get_xp(np)(_aliases.vecdot)
6064

6165
__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
6266
'acosh', 'asin', 'asinh', 'atan', 'atan2',

array_api_compat/numpy/linalg.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33

44
from ..common import _linalg
55
from .._internal import get_xp
6+
from ._aliases import (matmul, matrix_transpose, tensordot, vecdot)
67

78
import numpy as np
89

910
cross = get_xp(np)(_linalg.cross)
10-
matmul = get_xp(np)(_linalg.matmul)
1111
outer = get_xp(np)(_linalg.outer)
12-
tensordot = get_xp(np)(_linalg.tensordot)
1312
EighResult = _linalg.EighResult
1413
QRResult = _linalg.QRResult
1514
SlogdetResult = _linalg.SlogdetResult
@@ -22,9 +21,7 @@
2221
matrix_rank = get_xp(np)(_linalg.matrix_rank)
2322
pinv = get_xp(np)(_linalg.pinv)
2423
matrix_norm = get_xp(np)(_linalg.matrix_norm)
25-
matrix_transpose = get_xp(np)(_linalg.matrix_transpose)
2624
svdvals = get_xp(np)(_linalg.svdvals)
27-
vecdot = get_xp(np)(_linalg.vecdot)
2825
vector_norm = get_xp(np)(_linalg.vector_norm)
2926
diagonal = get_xp(np)(_linalg.diagonal)
3027
trace = get_xp(np)(_linalg.trace)

0 commit comments

Comments
 (0)