Skip to content

Commit 9811f3c

Browse files
committed
Wrap linalg functions that return namedtuples
1 parent cda717e commit 9811f3c

File tree

1 file changed

+34
-2
lines changed

1 file changed

+34
-2
lines changed

numpy_array_api_compat/linalg.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,49 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, NamedTuple
44
if TYPE_CHECKING:
55
from typing import Literal, Optional, Tuple, Union
66
from numpy import ndarray
77

88
import numpy as np
99
from numpy.core.numeric import normalize_axis_tuple
1010

11+
class EighResult(NamedTuple):
12+
eigenvalues: ndarray
13+
eigenvectors: ndarray
14+
15+
class QRResult(NamedTuple):
16+
Q: ndarray
17+
R: ndarray
18+
19+
class SlogdetResult(NamedTuple):
20+
sign: ndarray
21+
logabsdet: ndarray
22+
23+
class SVDResult(NamedTuple):
24+
U: ndarray
25+
S: ndarray
26+
Vh: ndarray
27+
28+
# These functions are the same as their NumPy counterparts except they return
29+
# a namedtuple.
30+
def eigh(x: ndarray, /) -> EighResult:
31+
return EighResult(*np.linalg.eigh(x))
32+
33+
def qr(x: ndarray, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult:
34+
return QRResult(*np.linalg.qr(x, mode=mode))
35+
36+
def slogdet(x: ndarray, /) -> SlogdetResult:
37+
return SlogdetResult(*np.linalg.slogdet(x))
38+
39+
def svd(x: ndarray, /, *, full_matrices: bool = True) -> SVDResult:
40+
return SVDResult(*np.linalg.svd(x, full_matrices=full_matrices))
41+
42+
# This function is not in NumPy.
1143
def matrix_norm(x: ndarray, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray:
1244
return np.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
1345

14-
# this function is new in the array API spec. Unlike transpose, it only
46+
# This function is new in the array API spec. Unlike transpose, it only
1547
# transposes the last two axes.
1648
def matrix_transpose(x: ndarray, /) -> ndarray:
1749
if x.ndim < 2:

0 commit comments

Comments
 (0)