We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f12be47 commit d41d0bdCopy full SHA for d41d0bd
array_api_tests/test_linalg.py
@@ -779,11 +779,20 @@ def test_vecdot(x1, x2, kw):
779
expected_shape.pop(axis)
780
expected_shape = tuple(expected_shape)
781
782
- out = xp.vecdot(x1, x2, **kw)
+ res = xp.vecdot(x1, x2, **kw)
783
784
- ph.assert_dtype("vecdot", [x1.dtype, x2.dtype], out.dtype)
+ ph.assert_dtype("vecdot", [x1.dtype, x2.dtype], res.dtype)
785
# TODO: assert shape and elements
786
- ph.assert_shape("vecdot", out.shape, expected_shape)
+ ph.assert_shape("vecdot", res.shape, expected_shape)
787
+
788
+ if x1.dtype in dh.int_dtypes:
789
+ def true_val(x, y, axix=-1):
790
+ return xp.sum(x*y, dtype=res.dtype)
791
+ else:
792
+ true_val = None
793
794
+ _test_stacks(linalg.vecdot, x1, x2, res=res, dims=0,
795
+ matrix_axes=(axis,), true_val=true_val)
796
797
# Insanely large orders might not work. There isn't a limit specified in the
798
# spec, so we just limit to reasonable values here.
0 commit comments