Skip to content

Commit 3cb9912

Browse files
committed
Fix test_cross to use assert_dtype and assert_shape helpers
1 parent afc8a25 commit 3cb9912

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

array_api_tests/test_linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,9 @@ def test_cross(x1_x2_kw):
191191

192192
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
193193

194-
assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "cross() did not return the correct dtype"
195-
assert res.shape == broadcasted_shape, "cross() did not return the correct shape"
194+
ph.assert_dtype("cross", in_dtype=[x1.dtype, x2.dtype],
195+
out_dtype=res.dtype)
196+
ph.assert_shape("cross", out_shape=res.shape, expected=broadcasted_shape)
196197

197198
def exact_cross(a, b):
198199
assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
@@ -800,7 +801,6 @@ def test_vecdot(x1, x2, data):
800801

801802
ph.assert_dtype("vecdot", in_dtype=[x1.dtype, x2.dtype],
802803
out_dtype=res.dtype)
803-
# TODO: assert shape and elements
804804
ph.assert_shape("vecdot", out_shape=res.shape, expected=expected_shape)
805805

806806
if x1.dtype in dh.int_dtypes:

0 commit comments

Comments
 (0)