|
4 | 4 | if TYPE_CHECKING:
|
5 | 5 | import torch
|
6 | 6 | array = torch.Tensor
|
| 7 | + from torch import dtype as Dtype |
| 8 | + from typing import Optional |
7 | 9 |
|
8 | 10 | from torch.linalg import *
|
9 | 11 |
|
|
12 | 14 | from torch import linalg as torch_linalg
|
13 | 15 | linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
|
14 | 16 |
|
15 |
| -# These are implemented in torch but aren't in the linalg namespace |
16 |
| -from torch import outer, trace |
17 |
| -from ._aliases import _fix_promotion, matrix_transpose, tensordot |
| 17 | +# outer is implemented in torch but aren't in the linalg namespace |
| 18 | +from torch import outer |
| 19 | +from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum |
18 | 20 |
|
19 | 21 | # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
|
20 | 22 | # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
|
@@ -49,6 +51,11 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:
|
49 | 51 | x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
|
50 | 52 | return torch.linalg.solve(x1, x2, **kwargs)
|
51 | 53 |
|
| 54 | +# torch.trace doesn't support the offset argument and doesn't support stacking |
| 55 | +def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: |
| 56 | + # Use our wrapped sum to make sure it does upcasting correctly |
| 57 | + return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) |
| 58 | + |
52 | 59 | __all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot',
|
53 | 60 | 'vecdot', 'solve']
|
54 | 61 |
|
|
0 commit comments