Skip to content

Commit e157ccd

Browse files
authored
BUG: Fix the implementation of numpy.array_api.vecdot (#21928)
* Fix the implementation of numpy.array_api.vecdot See https://data-apis.org/array-api/latest/API_specification/generated/signatures.linear_algebra_functions.vecdot.html * Use moveaxis + matmul instead of einsum in vecdot Original NumPy Commit: 0e960b985843ff99db06f89eadaa9f387b5a65f8
1 parent 0f6691f commit e157ccd

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

array_api_strict/linalg.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,18 @@ def trace(x: Array, /, *, offset: int = 0) -> Array:
379379
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
380380
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
381381
raise TypeError('Only numeric dtypes are allowed in vecdot')
382-
return tensordot(x1, x2, axes=((axis,), (axis,)))
382+
ndim = max(x1.ndim, x2.ndim)
383+
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
384+
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
385+
if x1_shape[axis] != x2_shape[axis]:
386+
raise ValueError("x1 and x2 must have the same size along the given axis")
387+
388+
x1_, x2_ = np.broadcast_arrays(x1._array, x2._array)
389+
x1_ = np.moveaxis(x1_, axis, -1)
390+
x2_ = np.moveaxis(x2_, axis, -1)
391+
392+
res = x1_[..., None, :] @ x2_[..., None]
393+
return Array._new(res[..., 0, 0])
383394

384395

385396
# Note: the name here is different from norm(). The array API norm is split

0 commit comments

Comments
 (0)