Skip to content

Commit 5a3bbbe

Browse files
committed
Finish torch wrappers for matmul, vecdot, and tensordot
1 parent 453ecb8 commit 5a3bbbe

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from typing import TYPE_CHECKING
1313
if TYPE_CHECKING:
14-
from typing import List, Optional, Tuple, Union
14+
from typing import List, Optional, Sequence, Tuple, Union
1515
from ..common._typing import Device
1616
from torch import dtype as Dtype
1717

@@ -83,14 +83,14 @@ def _f(x1, x2, /, **kwargs):
8383
"""
8484
return _f
8585

86-
def _fix_promotion(x1, x2):
86+
def _fix_promotion(x1, x2, only_scalar=True):
8787
if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
8888
return x1, x2
8989
# If an argument is 0-D pytorch downcasts the other argument
90-
if x1.shape == ():
90+
if not only_scalar or x1.shape == ():
9191
dtype = result_type(x1, x2)
9292
x2 = x2.to(dtype)
93-
if x2.shape == ():
93+
if not only_scalar or x2.shape == ():
9494
dtype = result_type(x1, x2)
9595
x1 = x1.to(dtype)
9696
return x1, x2
@@ -565,13 +565,22 @@ def unique_values(x: array) -> array:
565565

566566
def matmul(x1: array, x2: array, /, **kwargs) -> array:
567567
# torch.matmul doesn't type promote (but differently from _fix_promotion)
568-
dtype = result_type(x1, x2)
569-
x1 = x1.to(dtype)
570-
x2 = x2.to(dtype)
568+
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
571569
return torch.matmul(x1, x2, **kwargs)
572570

573571
matrix_transpose = get_xp(torch)(_aliases_matrix_transpose)
574-
vecdot = get_xp(torch)(_aliases_vecdot)
572+
_vecdot = get_xp(torch)(_aliases_vecdot)
573+
574+
def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
575+
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
576+
return _vecdot(x1, x2, axis=axis)
577+
578+
# torch.tensordot uses dims instead of axes
579+
def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array:
580+
# Note: torch.tensordot fails with integer dtypes when there is only 1
581+
# element in the axis (https://github.com/pytorch/pytorch/issues/84530).
582+
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
583+
return torch.tensordot(x1, x2, dims=axes, **kwargs)
575584

576585
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add',
577586
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
@@ -583,4 +592,4 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array:
583592
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full',
584593
'expand_dims', 'astype', 'broadcast_arrays', 'unique_all',
585594
'unique_counts', 'unique_inverse', 'unique_values',
586-
'matmul', 'matrix_transpose', 'vecdot']
595+
'matmul', 'matrix_transpose', 'vecdot', 'tensordot']

0 commit comments

Comments (0)