Skip to content

Commit 453ecb8

Browse files
committed
Add torch wrapper for matmul
1 parent 04eef18 commit 453ecb8

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,12 @@ def unique_inverse(x: array) -> UniqueInverseResult:
563563
def unique_values(x: array) -> array:
564564
return torch.unique(x)
565565

566+
def matmul(x1: array, x2: array, /, **kwargs) -> array:
567+
# torch.matmul doesn't type promote (but differently from _fix_promotion)
568+
dtype = result_type(x1, x2)
569+
x1 = x1.to(dtype)
570+
x2 = x2.to(dtype)
571+
return torch.matmul(x1, x2, **kwargs)
566572

567573
matrix_transpose = get_xp(torch)(_aliases_matrix_transpose)
568574
vecdot = get_xp(torch)(_aliases_vecdot)
@@ -577,4 +583,4 @@ def unique_values(x: array) -> array:
577583
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full',
578584
'expand_dims', 'astype', 'broadcast_arrays', 'unique_all',
579585
'unique_counts', 'unique_inverse', 'unique_values',
580-
'matrix_transpose', 'vecdot']
586+
'matmul', 'matrix_transpose', 'vecdot']

0 commit comments

Comments
 (0)