|
17 | 17 | from hypothesis import assume, given
|
18 | 18 | from hypothesis.strategies import (booleans, composite, tuples, floats,
|
19 | 19 | integers, shared, sampled_from, one_of,
|
20 |
| - data, just) |
| 20 | + data) |
21 | 21 | from ndindex import iter_indices
|
22 | 22 |
|
23 | 23 | import itertools
|
|
29 | 29 | invertible_matrices, two_mutual_arrays,
|
30 | 30 | mutually_promotable_dtypes, one_d_shapes,
|
31 | 31 | two_mutually_broadcastable_shapes,
|
| 32 | + mutually_broadcastable_shapes, |
32 | 33 | SQRT_MAX_ARRAY_SIZE, finite_matrices,
|
33 | 34 | rtol_shared_matrix_shapes, rtols, axes)
|
34 | 35 | from . import dtype_helpers as dh
|
@@ -756,20 +757,33 @@ def true_trace(x_stack):
|
756 | 757 |
|
757 | 758 |
|
758 | 759 | @given(
|
759 |
| - dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes), |
760 |
| - shape=shapes(min_dims=1), |
761 |
| - data=data(), |
| 760 | + *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)), |
| 761 | + kwargs(axis=integers()), |
762 | 762 | )
|
763 |
| -def test_vecdot(dtypes, shape, data): |
| 763 | +def test_vecdot(x1, x2, kw): |
764 | 764 | # TODO: vary shapes, test different axis arguments
|
765 |
| - x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shape), label="x1") |
766 |
| - x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shape), label="x2") |
767 |
| - kw = data.draw(kwargs(axis=just(-1))) |
| 765 | + broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape) |
| 766 | + ndim = len(broadcasted_shape) |
| 767 | + axis = kw.get('axis', -1) |
| 768 | + if not (-ndim <= axis < ndim): |
| 769 | + ph.raises(Exception, lambda: xp.vecdot(x1, x2, **kw), |
| 770 | + f"vecdot did not raise an exception for invalid axis ({ndim=}, {kw=})") |
| 771 | + return |
| 772 | + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) |
| 773 | + x2_shape = (1,)*(ndim - x1.ndim) + tuple(x2.shape) |
| 774 | + if x1_shape[axis] != x2_shape[axis]: |
| 775 | + ph.raises(Exception, lambda: xp.vecdot(x1, x2, **kw), |
| 776 | + "vecdot did not raise an exception for invalid shapes") |
| 777 | + return |
| 778 | + expected_shape = list(broadcasted_shape) |
| 779 | + expected_shape.pop(axis) |
| 780 | + expected_shape = tuple(expected_shape) |
768 | 781 |
|
769 | 782 | out = xp.vecdot(x1, x2, **kw)
|
770 | 783 |
|
771 |
| - ph.assert_dtype("vecdot", dtypes, out.dtype) |
| 784 | + ph.assert_dtype("vecdot", [x1.dtype, x2.dtype], out.dtype) |
772 | 785 | # TODO: assert shape and elements
|
| 786 | + ph.assert_shape("vecdot", out.shape, expected_shape) |
773 | 787 |
|
774 | 788 | # Insanely large orders might not work. There isn't a limit specified in the
|
775 | 789 | # spec, so we just limit to reasonable values here.
|
|
0 commit comments