Skip to content

Commit 04eef18

Browse files
committed
Add main namespace linalg functions to the torch wrapper
1 parent c441f33 commit 04eef18

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

array_api_compat/common/_aliases.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,12 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
436436
if x1_shape[axis] != x2_shape[axis]:
437437
raise ValueError("x1 and x2 must have the same size along the given axis")
438438

439-
x1_, x2_ = xp.broadcast_arrays(x1, x2)
439+
if hasattr(xp, 'broadcast_tensors'):
440+
_broadcast = xp.broadcast_tensors
441+
else:
442+
_broadcast = xp.broadcast_arrays
443+
444+
x1_, x2_ = _broadcast(x1, x2)
440445
x1_ = xp.moveaxis(x1_, axis, -1)
441446
x2_ = xp.moveaxis(x2_, axis, -1)
442447

array_api_compat/torch/_aliases.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
from functools import wraps
44
from builtins import all as builtin_all
55

6-
from ..common._aliases import (UniqueAllResult, UniqueCountsResult, UniqueInverseResult)
6+
from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
7+
UniqueInverseResult,
8+
matrix_transpose as _aliases_matrix_transpose,
9+
vecdot as _aliases_vecdot)
10+
from .._internal import get_xp
711

812
from typing import TYPE_CHECKING
913
if TYPE_CHECKING:
@@ -559,6 +563,10 @@ def unique_inverse(x: array) -> UniqueInverseResult:
559563
def unique_values(x: array) -> array:
560564
return torch.unique(x)
561565

566+
567+
matrix_transpose = get_xp(torch)(_aliases_matrix_transpose)
568+
vecdot = get_xp(torch)(_aliases_vecdot)
569+
562570
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add',
563571
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
564572
'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal',
@@ -568,4 +576,5 @@ def unique_values(x: array) -> array:
568576
'mean', 'std', 'var', 'concat', 'squeeze', 'flip', 'roll',
569577
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full',
570578
'expand_dims', 'astype', 'broadcast_arrays', 'unique_all',
571-
'unique_counts', 'unique_inverse', 'unique_values']
579+
'unique_counts', 'unique_inverse', 'unique_values',
580+
'matrix_transpose', 'vecdot']

0 commit comments

Comments
 (0)