Skip to content

Commit 74add08

Browse files
committed
Fix test_vecdot
The NumPy implementation is currently incorrect, so I am not 100% if this test is completely correct.
1 parent f494b45 commit 74add08

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

array_api_tests/test_linalg.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from hypothesis import assume, given
1818
from hypothesis.strategies import (booleans, composite, tuples, floats,
1919
integers, shared, sampled_from, one_of,
20-
data, just)
20+
data)
2121
from ndindex import iter_indices
2222

2323
import itertools
@@ -29,6 +29,7 @@
2929
invertible_matrices, two_mutual_arrays,
3030
mutually_promotable_dtypes, one_d_shapes,
3131
two_mutually_broadcastable_shapes,
32+
mutually_broadcastable_shapes,
3233
SQRT_MAX_ARRAY_SIZE, finite_matrices,
3334
rtol_shared_matrix_shapes, rtols, axes)
3435
from . import dtype_helpers as dh
@@ -756,20 +757,33 @@ def true_trace(x_stack):
756757

757758

758759
@given(
759-
dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes),
760-
shape=shapes(min_dims=1),
761-
data=data(),
760+
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
761+
kwargs(axis=integers()),
762762
)
763-
def test_vecdot(dtypes, shape, data):
763+
def test_vecdot(x1, x2, kw):
764764
# TODO: vary shapes, test different axis arguments
765-
x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shape), label="x1")
766-
x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shape), label="x2")
767-
kw = data.draw(kwargs(axis=just(-1)))
765+
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
766+
ndim = len(broadcasted_shape)
767+
axis = kw.get('axis', -1)
768+
if not (-ndim <= axis < ndim):
769+
ph.raises(Exception, lambda: xp.vecdot(x1, x2, **kw),
770+
f"vecdot did not raise an exception for invalid axis ({ndim=}, {kw=})")
771+
return
772+
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
773+
x2_shape = (1,)*(ndim - x1.ndim) + tuple(x2.shape)
774+
if x1_shape[axis] != x2_shape[axis]:
775+
ph.raises(Exception, lambda: xp.vecdot(x1, x2, **kw),
776+
"vecdot did not raise an exception for invalid shapes")
777+
return
778+
expected_shape = list(broadcasted_shape)
779+
expected_shape.pop(axis)
780+
expected_shape = tuple(expected_shape)
768781

769782
out = xp.vecdot(x1, x2, **kw)
770783

771-
ph.assert_dtype("vecdot", dtypes, out.dtype)
784+
ph.assert_dtype("vecdot", [x1.dtype, x2.dtype], out.dtype)
772785
# TODO: assert shape and elements
786+
ph.assert_shape("vecdot", out.shape, expected_shape)
773787

774788
# Insanely large orders might not work. There isn't a limit specified in the
775789
# spec, so we just limit to reasonable values here.

0 commit comments

Comments
 (0)