3
3
from functools import wraps
4
4
from builtins import all as builtin_all
5
5
6
- from ..common ._aliases import (UniqueAllResult , UniqueCountsResult , UniqueInverseResult )
6
+ from ..common ._aliases import (UniqueAllResult , UniqueCountsResult ,
7
+ UniqueInverseResult ,
8
+ matrix_transpose as _aliases_matrix_transpose ,
9
+ vecdot as _aliases_vecdot )
10
+ from .._internal import get_xp
7
11
8
12
from typing import TYPE_CHECKING
9
13
if TYPE_CHECKING :
@@ -559,6 +563,10 @@ def unique_inverse(x: array) -> UniqueInverseResult:
559
563
def unique_values (x : array ) -> array :
560
564
return torch .unique (x )
561
565
566
+
567
+ matrix_transpose = get_xp (torch )(_aliases_matrix_transpose )
568
+ vecdot = get_xp (torch )(_aliases_vecdot )
569
+
562
570
__all__ = ['result_type' , 'can_cast' , 'permute_dims' , 'bitwise_invert' , 'add' ,
563
571
'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
564
572
'bitwise_right_shift' , 'bitwise_xor' , 'divide' , 'equal' ,
@@ -568,4 +576,5 @@ def unique_values(x: array) -> array:
568
576
'mean' , 'std' , 'var' , 'concat' , 'squeeze' , 'flip' , 'roll' ,
569
577
'nonzero' , 'where' , 'arange' , 'eye' , 'linspace' , 'full' ,
570
578
'expand_dims' , 'astype' , 'broadcast_arrays' , 'unique_all' ,
571
- 'unique_counts' , 'unique_inverse' , 'unique_values' ]
579
+ 'unique_counts' , 'unique_inverse' , 'unique_values' ,
580
+ 'matrix_transpose' , 'vecdot' ]
0 commit comments