 
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
-    from typing import List, Optional, Tuple, Union
+    from typing import List, Optional, Sequence, Tuple, Union
     from ..common._typing import Device
     from torch import dtype as Dtype
 
@@ -83,14 +83,14 @@ def _f(x1, x2, /, **kwargs):
 """
     return _f
 
-def _fix_promotion(x1, x2):
+def _fix_promotion(x1, x2, only_scalar=True):
     if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
         return x1, x2
     # If an argument is 0-D pytorch downcasts the other argument
-    if x1.shape == ():
+    if not only_scalar or x1.shape == ():
        dtype = result_type(x1, x2)
        x2 = x2.to(dtype)
-    if x2.shape == ():
+    if not only_scalar or x2.shape == ():
        dtype = result_type(x1, x2)
        x1 = x1.to(dtype)
     return x1, x2
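For context, here is a minimal sketch (not part of the changeset; tensors and printed dtypes are illustrative) of the promotion problem the reworked helper addresses. `_fix_promotion` and `result_type` refer to the helpers defined in this module.

```python
import torch

a = torch.ones((2, 2), dtype=torch.int8)   # dimensioned operand
b = torch.tensor(3, dtype=torch.int64)     # 0-D operand

# Plain torch keeps the dimensioned operand's width and downcasts:
print(torch.add(a, b).dtype)               # torch.int8

# _fix_promotion(a, b) (default only_scalar=True) casts `a` up to
# result_type(a, b) because `b` is 0-D, so the wrapped add returns int64.
# With only_scalar=False both operands are always cast to the common
# dtype, which is what matmul/vecdot/tensordot need, since torch.matmul
# does not type-promote mixed dtypes at all.
```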
@@ -565,13 +565,22 @@ def unique_values(x: array) -> array:
 
 def matmul(x1: array, x2: array, /, **kwargs) -> array:
     # torch.matmul doesn't type promote (but differently from _fix_promotion)
-    dtype = result_type(x1, x2)
-    x1 = x1.to(dtype)
-    x2 = x2.to(dtype)
+    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
     return torch.matmul(x1, x2, **kwargs)
 
 matrix_transpose = get_xp(torch)(_aliases_matrix_transpose)
-vecdot = get_xp(torch)(_aliases_vecdot)
+_vecdot = get_xp(torch)(_aliases_vecdot)
+
+def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
+    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    return _vecdot(x1, x2, axis=axis)
+
+# torch.tensordot uses dims instead of axes
+def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array:
+    # Note: torch.tensordot fails with integer dtypes when there is only 1
+    # element in the axis (https://github.com/pytorch/pytorch/issues/84530).
+    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
+    return torch.tensordot(x1, x2, dims=axes, **kwargs)
 
 __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add',
            'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
@@ -583,4 +592,4 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array:
            'nonzero', 'where', 'arange', 'eye', 'linspace', 'full',
            'expand_dims', 'astype', 'broadcast_arrays', 'unique_all',
            'unique_counts', 'unique_inverse', 'unique_values',
-           'matmul', 'matrix_transpose', 'vecdot']
+           'matmul', 'matrix_transpose', 'vecdot', 'tensordot']
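A short usage sketch of the promotion the new wrappers perform before delegating to torch (illustrative shapes and dtypes; `torch.result_type` stands in here for this module's `result_type` helper):

```python
import torch

a = torch.ones((3, 4), dtype=torch.float32)
b = torch.ones((4, 5), dtype=torch.float64)

# torch.matmul / torch.tensordot do not promote mixed dtypes themselves,
# so the wrappers first cast both operands to their common result type
# via _fix_promotion(x1, x2, only_scalar=False). Done by hand:
dtype = torch.result_type(a, b)            # torch.float64
out = torch.tensordot(a.to(dtype), b.to(dtype), dims=([1], [0]))
print(out.shape, out.dtype)                # torch.Size([3, 5]) torch.float64
```

The `tensordot` wrapper additionally translates the array-API `axes` keyword into torch's `dims` argument, as in the call above.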