|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -from typing import TYPE_CHECKING |
| 3 | +from typing import TYPE_CHECKING, NamedTuple |
4 | 4 | if TYPE_CHECKING:
|
5 | 5 | from typing import Literal, Optional, Tuple, Union
|
6 | 6 | from numpy import ndarray
|
7 | 7 |
|
8 | 8 | import numpy as np
|
9 | 9 | from numpy.core.numeric import normalize_axis_tuple
|
10 | 10 |
|
| 11 | +class EighResult(NamedTuple): |
| 12 | + eigenvalues: ndarray |
| 13 | + eigenvectors: ndarray |
| 14 | + |
| 15 | +class QRResult(NamedTuple): |
| 16 | + Q: ndarray |
| 17 | + R: ndarray |
| 18 | + |
| 19 | +class SlogdetResult(NamedTuple): |
| 20 | + sign: ndarray |
| 21 | + logabsdet: ndarray |
| 22 | + |
| 23 | +class SVDResult(NamedTuple): |
| 24 | + U: ndarray |
| 25 | + S: ndarray |
| 26 | + Vh: ndarray |
| 27 | + |
| 28 | +# These functions are the same as their NumPy counterparts except they return |
| 29 | +# a namedtuple. |
| 30 | +def eigh(x: ndarray, /) -> EighResult: |
| 31 | + return EighResult(*np.linalg.eigh(x)) |
| 32 | + |
| 33 | +def qr(x: ndarray, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: |
| 34 | + return QRResult(*np.linalg.qr(x, mode=mode)) |
| 35 | + |
| 36 | +def slogdet(x: ndarray, /) -> SlogdetResult: |
| 37 | + return SlogdetResult(*np.linalg.slogdet(x)) |
| 38 | + |
| 39 | +def svd(x: ndarray, /, *, full_matrices: bool = True) -> SVDResult: |
| 40 | + return SVDResult(*np.linalg.svd(x, full_matrices=full_matrices)) |
| 41 | + |
| 42 | +# This function is not in NumPy. |
11 | 43 | def matrix_norm(x: ndarray, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray:
|
12 | 44 | return np.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
|
13 | 45 |
|
14 |
| -# this function is new in the array API spec. Unlike transpose, it only |
| 46 | +# This function is new in the array API spec. Unlike transpose, it only |
15 | 47 | # transposes the last two axes.
|
16 | 48 | def matrix_transpose(x: ndarray, /) -> ndarray:
|
17 | 49 | if x.ndim < 2:
|
|
0 commit comments