|
45 | 45 |
|
46 | 46 | def assert_equal(x, y, msg_extra=None):
|
47 | 47 | extra = '' if not msg_extra else f' ({msg_extra})'
|
48 |
| - if x.dtype in dh.float_dtypes: |
| 48 | + if x.dtype in dh.all_float_dtypes: |
49 | 49 | # It's too difficult to do an approximately equal test here because
|
50 | 50 | # different routines can give completely different answers, and even
|
51 | 51 | # when it does work, the elementwise comparisons are too slow. So for
|
@@ -701,7 +701,8 @@ def test_tensordot(x1, x2, kw):
|
701 | 701 | # TODO: vary shapes, vary contracted axes, test different axes arguments
|
702 | 702 | res = xp.tensordot(x1, x2, **kw)
|
703 | 703 |
|
704 |
| - ph.assert_dtype("tensordot", [x1.dtype, x2.dtype], res.dtype) |
| 704 | + ph.assert_dtype("tensordot", in_dtype=[x1.dtype, x2.dtype], |
| 705 | + out_dtype=res.dtype) |
705 | 706 |
|
706 | 707 | axes = _axes = kw.get('axes', 2)
|
707 | 708 |
|
@@ -785,9 +786,10 @@ def test_vecdot(x1, x2, kw):
|
785 | 786 |
|
786 | 787 | res = xp.vecdot(x1, x2, **kw)
|
787 | 788 |
|
788 |
| - ph.assert_dtype("vecdot", [x1.dtype, x2.dtype], res.dtype) |
| 789 | + ph.assert_dtype("vecdot", in_dtype=[x1.dtype, x2.dtype], |
| 790 | + out_dtype=res.dtype) |
789 | 791 | # TODO: assert shape and elements
|
790 |
| - ph.assert_shape("vecdot", res.shape, expected_shape) |
| 792 | + ph.assert_shape("vecdot", out_shape=res.shape, expected=expected_shape) |
791 | 793 |
|
792 | 794 | if x1.dtype in dh.int_dtypes:
|
793 | 795 | def true_val(x, y, axis=-1):
|
@@ -827,9 +829,11 @@ def test_vector_norm(x, data):
|
827 | 829 |
|
828 | 830 | _axes = sh.normalise_axis(axis, x.ndim)
|
829 | 831 |
|
830 |
| - ph.assert_keepdimable_shape('linalg.vector_norm', res.shape, x.shape, |
831 |
| - _axes, keepdims, **kw) |
832 |
| - ph.assert_dtype('linalg.vector_norm', x.dtype, res.dtype) |
| 832 | + ph.assert_keepdimable_shape('linalg.vector_norm', out_shape=res.shape, |
| 833 | + in_shape=x.shape, axes=_axes, |
| 834 | + keepdims=keepdims, kw=kw) |
| 835 | + ph.assert_dtype('linalg.vector_norm', in_dtype=x.dtype, |
| 836 | + out_dtype=res.dtype) |
833 | 837 |
|
834 | 838 | _kw = kw.copy()
|
835 | 839 | _kw.pop('axis', None)
|
|
0 commit comments