@@ -45,12 +45,44 @@ def slogdet(x: ndarray, /) -> SlogdetResult:
45
45
def svd (x : ndarray , / , * , full_matrices : bool = True ) -> SVDResult :
46
46
return SVDResult (* np .linalg .svd (x , full_matrices = full_matrices ))
47
47
48
- # This function is not in NumPy.
48
+ # These functions have additional keyword arguments
49
+
50
+ # The upper keyword argument is new from NumPy
51
+ def cholesky (x : ndarray , / , * , upper : bool = False ) -> ndarray :
52
+ L = np .linalg .cholesky (x )
53
+ if upper :
54
+ return matrix_transpose (L )
55
+ return L
56
+
57
+ # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
58
+ # Note that it has a different semantic meaning from tol and rcond.
59
+ def matrix_rank (x : ndarray , / , * , rtol : Optional [Union [float , ndarray ]] = None ) -> ndarray :
60
+ # this is different from np.linalg.matrix_rank, which supports 1
61
+ # dimensional arrays.
62
+ if x .ndim < 2 :
63
+ raise np .linalg .LinAlgError ("1-dimensional array given. Array must be at least two-dimensional" )
64
+ S = np .linalg .svd (x , compute_uv = False )
65
+ if rtol is None :
66
+ tol = S .max (axis = - 1 , keepdims = True ) * max (x .shape [- 2 :]) * np .finfo (S .dtype ).eps
67
+ else :
68
+ # this is different from np.linalg.matrix_rank, which does not
69
+ # multiply the tolerance by the largest singular value.
70
+ tol = S .max (axis = - 1 , keepdims = True )* np .asarray (rtol )[..., np .newaxis ]
71
+ return np .count_nonzero (S > tol , axis = - 1 )
72
+
73
+ def pinv (x : ndarray , / , * , rtol : Optional [Union [float , ndarray ]] = None ) -> ndarray :
74
+ # this is different from np.linalg.pinv, which does not multiply the
75
+ # default tolerance by max(M, N).
76
+ if rtol is None :
77
+ rtol = max (x .shape [- 2 :]) * np .finfo (x .dtype ).eps
78
+ return np .linalg .pinv (x , rcond = rtol )
79
+
80
+ # These functions are new in the array API spec
81
+
49
82
def matrix_norm (x : ndarray , / , * , keepdims : bool = False , ord : Optional [Union [int , float , Literal ['fro' , 'nuc' ]]] = 'fro' ) -> ndarray :
50
83
return np .linalg .norm (x , axis = (- 2 , - 1 ), keepdims = keepdims , ord = ord )
51
84
52
- # This function is new in the array API spec. Unlike transpose, it only
53
- # transposes the last two axes.
85
+ # Unlike transpose, matrix_transpose only transposes the last two axes.
54
86
def matrix_transpose (x : ndarray , / ) -> ndarray :
55
87
if x .ndim < 2 :
56
88
raise ValueError ("x must be at least 2-dimensional for matrix_transpose" )
@@ -61,7 +93,6 @@ def matrix_transpose(x: ndarray, /) -> ndarray:
61
93
def svdvals (x : ndarray , / ) -> Union [ndarray , Tuple [ndarray , ...]]:
62
94
return np .linalg .svd (x , compute_uv = False )
63
95
64
- # vecdot is not in NumPy
65
96
def vecdot (x1 : ndarray , x2 : ndarray , / , * , axis : int = - 1 ) -> ndarray :
66
97
ndim = max (x1 .ndim , x2 .ndim )
67
98
x1_shape = (1 ,)* (ndim - x1 .ndim ) + tuple (x1 .shape )
@@ -111,6 +142,7 @@ def vector_norm(x: ndarray, /, *, axis: Optional[Union[int, Tuple[int, ...]]] =
111
142
return res
112
143
113
144
__all__ = linalg_all .copy ()
114
- __all__ += ['cross' , 'diagonal' , 'matmul' , 'matrix_norm' , 'matrix_transpose' ,
115
- 'outer' , 'svdvals' , 'tensordot' , 'trace' , 'vecdot' , 'vector_norm' ,
116
- 'EighResult' , 'QRResult' , 'SlogdetResult' , 'SVDResult' ]
145
+ __all__ += ['cross' , 'diagonal' , 'matmul' , 'cholesky' , 'matrix_rank' , 'pinv' ,
146
+ 'matrix_norm' , 'matrix_transpose' , 'outer' , 'svdvals' ,
147
+ 'tensordot' , 'trace' , 'vecdot' , 'vector_norm' , 'EighResult' ,
148
+ 'QRResult' , 'SlogdetResult' , 'SVDResult' ]
0 commit comments