Skip to content

Commit d41d0bd

Browse files
committed
Expand vecdot tests
1 parent f12be47 commit d41d0bd

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

array_api_tests/test_linalg.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -779,11 +779,20 @@ def test_vecdot(x1, x2, kw):
779779
expected_shape.pop(axis)
780780
expected_shape = tuple(expected_shape)
781781

782-
out = xp.vecdot(x1, x2, **kw)
782+
res = xp.vecdot(x1, x2, **kw)
783783

784-
ph.assert_dtype("vecdot", [x1.dtype, x2.dtype], out.dtype)
784+
ph.assert_dtype("vecdot", [x1.dtype, x2.dtype], res.dtype)
785785
# TODO: assert shape and elements
786-
ph.assert_shape("vecdot", out.shape, expected_shape)
786+
ph.assert_shape("vecdot", res.shape, expected_shape)
787+
788+
if x1.dtype in dh.int_dtypes:
789+
def true_val(x, y, axix=-1):
790+
return xp.sum(x*y, dtype=res.dtype)
791+
else:
792+
true_val = None
793+
794+
_test_stacks(linalg.vecdot, x1, x2, res=res, dims=0,
795+
matrix_axes=(axis,), true_val=true_val)
787796

788797
# Insanely large orders might not work. There isn't a limit specified in the
789798
# spec, so we just limit to reasonable values here.

0 commit comments

Comments
 (0)