Skip to content

Commit cffd076

Browse files
committed
Fix some issues in linalg tests from recent merge
1 parent c51216b commit cffd076

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
sampled_from, shared, builds)
1313

1414
from . import _array_module as xp, api_version
15+
from . import array_helpers as ah
1516
from . import dtype_helpers as dh
1617
from . import shape_helpers as sh
1718
from . import xps

array_api_tests/test_linalg.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646
def assert_equal(x, y, msg_extra=None):
4747
extra = '' if not msg_extra else f' ({msg_extra})'
48-
if x.dtype in dh.float_dtypes:
48+
if x.dtype in dh.all_float_dtypes:
4949
# It's too difficult to do an approximately equal test here because
5050
# different routines can give completely different answers, and even
5151
# when it does work, the elementwise comparisons are too slow. So for
@@ -701,7 +701,8 @@ def test_tensordot(x1, x2, kw):
701701
# TODO: vary shapes, vary contracted axes, test different axes arguments
702702
res = xp.tensordot(x1, x2, **kw)
703703

704-
ph.assert_dtype("tensordot", [x1.dtype, x2.dtype], res.dtype)
704+
ph.assert_dtype("tensordot", in_dtype=[x1.dtype, x2.dtype],
705+
out_dtype=res.dtype)
705706

706707
axes = _axes = kw.get('axes', 2)
707708

@@ -785,9 +786,10 @@ def test_vecdot(x1, x2, kw):
785786

786787
res = xp.vecdot(x1, x2, **kw)
787788

788-
ph.assert_dtype("vecdot", [x1.dtype, x2.dtype], res.dtype)
789+
ph.assert_dtype("vecdot", in_dtype=[x1.dtype, x2.dtype],
790+
out_dtype=res.dtype)
789791
# TODO: assert shape and elements
790-
ph.assert_shape("vecdot", res.shape, expected_shape)
792+
ph.assert_shape("vecdot", out_shape=res.shape, expected=expected_shape)
791793

792794
if x1.dtype in dh.int_dtypes:
793795
def true_val(x, y, axis=-1):
@@ -827,9 +829,11 @@ def test_vector_norm(x, data):
827829

828830
_axes = sh.normalise_axis(axis, x.ndim)
829831

830-
ph.assert_keepdimable_shape('linalg.vector_norm', res.shape, x.shape,
831-
_axes, keepdims, **kw)
832-
ph.assert_dtype('linalg.vector_norm', x.dtype, res.dtype)
832+
ph.assert_keepdimable_shape('linalg.vector_norm', out_shape=res.shape,
833+
in_shape=x.shape, axes=_axes,
834+
keepdims=keepdims, kw=kw)
835+
ph.assert_dtype('linalg.vector_norm', in_dtype=x.dtype,
836+
out_dtype=res.dtype)
833837

834838
_kw = kw.copy()
835839
_kw.pop('axis', None)

0 commit comments

Comments
 (0)