Skip to content

Commit 547f007

Browse files
committed
Add cholesky(), matrix_rank(), and pinv() wrappers
1 parent e7ac55b commit 547f007

File tree

1 file changed

+39
-7
lines changed

1 file changed

+39
-7
lines changed

numpy_array_api_compat/linalg.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,44 @@ def slogdet(x: ndarray, /) -> SlogdetResult:
4545
def svd(x: ndarray, /, *, full_matrices: bool = True) -> SVDResult:
4646
return SVDResult(*np.linalg.svd(x, full_matrices=full_matrices))
4747

48-
# This function is not in NumPy.
48+
# These functions have additional keyword arguments
49+
50+
# The upper keyword argument is new from NumPy
51+
def cholesky(x: ndarray, /, *, upper: bool = False) -> ndarray:
52+
L = np.linalg.cholesky(x)
53+
if upper:
54+
return matrix_transpose(L)
55+
return L
56+
57+
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
58+
# Note that it has a different semantic meaning from tol and rcond.
59+
def matrix_rank(x: ndarray, /, *, rtol: Optional[Union[float, ndarray]] = None) -> ndarray:
60+
# this is different from np.linalg.matrix_rank, which supports 1
61+
# dimensional arrays.
62+
if x.ndim < 2:
63+
raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
64+
S = np.linalg.svd(x, compute_uv=False)
65+
if rtol is None:
66+
tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * np.finfo(S.dtype).eps
67+
else:
68+
# this is different from np.linalg.matrix_rank, which does not
69+
# multiply the tolerance by the largest singular value.
70+
tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis]
71+
return np.count_nonzero(S > tol, axis=-1)
72+
73+
def pinv(x: ndarray, /, *, rtol: Optional[Union[float, ndarray]] = None) -> ndarray:
74+
# this is different from np.linalg.pinv, which does not multiply the
75+
# default tolerance by max(M, N).
76+
if rtol is None:
77+
rtol = max(x.shape[-2:]) * np.finfo(x.dtype).eps
78+
return np.linalg.pinv(x, rcond=rtol)
79+
80+
# These functions are new in the array API spec
81+
4982
def matrix_norm(x: ndarray, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray:
5083
return np.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
5184

52-
# This function is new in the array API spec. Unlike transpose, it only
53-
# transposes the last two axes.
85+
# Unlike transpose, matrix_transpose only transposes the last two axes.
5486
def matrix_transpose(x: ndarray, /) -> ndarray:
5587
if x.ndim < 2:
5688
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
@@ -61,7 +93,6 @@ def matrix_transpose(x: ndarray, /) -> ndarray:
6193
def svdvals(x: ndarray, /) -> Union[ndarray, Tuple[ndarray, ...]]:
6294
return np.linalg.svd(x, compute_uv=False)
6395

64-
# vecdot is not in NumPy
6596
def vecdot(x1: ndarray, x2: ndarray, /, *, axis: int = -1) -> ndarray:
6697
ndim = max(x1.ndim, x2.ndim)
6798
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
@@ -111,6 +142,7 @@ def vector_norm(x: ndarray, /, *, axis: Optional[Union[int, Tuple[int, ...]]] =
111142
return res
112143

113144
__all__ = linalg_all.copy()
114-
__all__ += ['cross', 'diagonal', 'matmul', 'matrix_norm', 'matrix_transpose',
115-
'outer', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm',
116-
'EighResult', 'QRResult', 'SlogdetResult', 'SVDResult']
145+
__all__ += ['cross', 'diagonal', 'matmul', 'cholesky', 'matrix_rank', 'pinv',
146+
'matrix_norm', 'matrix_transpose', 'outer', 'svdvals',
147+
'tensordot', 'trace', 'vecdot', 'vector_norm', 'EighResult',
148+
'QRResult', 'SlogdetResult', 'SVDResult']

0 commit comments

Comments
 (0)