Skip to content

Commit 486ca51

Browse files
committed
Add a torch wrapper for trace
torch.trace doesn't support stacking or the outer argument.
1 parent 87431b7 commit 486ca51

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

array_api_compat/torch/linalg.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
if TYPE_CHECKING:
55
import torch
66
array = torch.Tensor
7+
from torch import dtype as Dtype
8+
from typing import Optional
79

810
from torch.linalg import *
911

@@ -12,9 +14,9 @@
1214
from torch import linalg as torch_linalg
1315
linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')]
1416

15-
# These are implemented in torch but aren't in the linalg namespace
16-
from torch import outer, trace
17-
from ._aliases import _fix_promotion, matrix_transpose, tensordot
17+
# outer is implemented in torch but aren't in the linalg namespace
18+
from torch import outer
19+
from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum
1820

1921
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
2022
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
@@ -49,6 +51,11 @@ def solve(x1: array, x2: array, /, **kwargs) -> array:
4951
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
5052
return torch.linalg.solve(x1, x2, **kwargs)
5153

54+
# torch.trace doesn't support the offset argument and doesn't support stacking
55+
def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array:
56+
# Use our wrapped sum to make sure it does upcasting correctly
57+
return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype)
58+
5259
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot',
5360
'vecdot', 'solve']
5461

0 commit comments

Comments
 (0)