We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6e06805 commit 121f54fCopy full SHA for 121f54f
array_api_strict/linalg.py
@@ -379,7 +379,18 @@ def trace(x: Array, /, *, offset: int = 0) -> Array:
379
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
380
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
381
raise TypeError('Only numeric dtypes are allowed in vecdot')
382
- return tensordot(x1, x2, axes=((axis,), (axis,)))
+ ndim = max(x1.ndim, x2.ndim)
383
+ x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
384
+ x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
385
+ if x1_shape[axis] != x2_shape[axis]:
386
+ raise ValueError("x1 and x2 must have the same size along the given axis")
387
+
388
+ x1_, x2_ = np.broadcast_arrays(x1._array, x2._array)
389
+ x1_ = np.moveaxis(x1_, axis, -1)
390
+ x2_ = np.moveaxis(x2_, axis, -1)
391
392
+ res = x1_[..., None, :] @ x2_[..., None]
393
+ return Array._new(res[..., 0, 0])
394
395
396
# Note: the name here is different from norm(). The array API norm is split
0 commit comments