We have the following test failure on ROCm. I am not filing a bug report because I am unsure whether it is a bug:

```
self = <sparse_test.SparseObjectTest testMethod=test_matmul16>, shape = (5, 8)
dtype = <class 'numpy.complex64'>
Obj = <class 'jax.experimental.sparse.csr.CSC'>, bshape = (8, 3)

    @parameterized.parameters(itertools.chain.from_iterable(
        jtu.sample_product_testcases(
          [dict(shape=shape, bshape=bshape)
           for shape in [(5, 8), (8, 5), (5, 5), (8, 8)]
           for bshape in [shape[-1:] + s for s in [(), (3,), (4,)]]
          ],
          Obj=[Obj],
          dtype=jtu.dtypes.floating + jtu.dtypes.complex,
        )
        for Obj in [sparse.CSR, sparse.CSC, sparse.COO, sparse.BCOO]))
    @jax.default_matmul_precision("float32")
    def test_matmul(self, shape, dtype, Obj, bshape):
      rng = sptu.rand_sparse(self.rng(), post=jnp.array)
      rng_b = jtu.rand_default(self.rng())
      M = rng(shape, dtype)
      Msp = Obj.fromdense(M)
      # Test matching type
      x = rng_b(bshape, dtype)
      x = jnp.asarray(x)
>     self.assertAllClose(
          M @ x, Msp @ x, rtol=sptu.MATMUL_TOL, atol=sptu.MATMUL_TOL
      )

tests/sparse_test.py:969:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jax/_src/test_util.py:1103: in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
jax/_src/test_util.py:1068: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jax/_src/public_test_util.py:127: in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (<function assert_allclose.<locals>.compare at 0x7f508885dca0>, array([[ -2.2831573 -3.6853049j, 4.245804 -2.4329...anj],
       [ nan +nanj, nan +nanj,
         nan +nanj]], dtype=complex64))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-05, atol=1e-05', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError:
E           Not equal to tolerance rtol=1e-05, atol=1e-05
E
E           x and y nan location mismatch:
E            x: array([[ -2.283157 -3.685305j,   4.245804 -2.432906j,
E                   24.530151+43.617775j],
E                  [  0.122184 +4.359858j,  -0.106578 -2.045549j,...
E            y: array([[ -2.283158 -3.685305j,   4.245805 -2.432905j,
E                   24.53015 +43.617775j],
E                  [  0.122184 +4.359858j,  -0.106578 -2.045549j,...

/pyenv/versions/3.9.0/lib/python3.9/contextlib.py:79: AssertionError
```
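The `x and y nan location mismatch` message is raised by NumPy rather than by JAX: `np.testing.assert_allclose` is called with `equal_nan=True` (see `kwds` in the traceback), which accepts NaNs only when both arrays have them at the same positions, so the tolerances never come into play for those entries. A minimal NumPy-only sketch of the same kind of failure, using made-up values rather than the actual test data:

```python
import numpy as np

# Made-up values: the two results agree except that one of them has a NaN
# where the other has a finite number.
dense = np.array([[1.0 + 2.0j, 3.0 - 1.0j]], dtype=np.complex64)
sparse = np.array([[1.0 + 2.0j, complex(np.nan, np.nan)]], dtype=np.complex64)

try:
    # equal_nan=True tolerates NaNs only at identical positions in both arrays,
    # so this fails regardless of rtol/atol.
    np.testing.assert_allclose(dense, sparse, rtol=1e-5, atol=1e-5, equal_nan=True)
    mismatch_reported = False
except AssertionError as err:
    # NumPy 1.x says "x and y nan location mismatch"; newer versions drop "x and y".
    mismatch_reported = "nan location mismatch" in str(err)

print(mismatch_reported)  # True
```

In other words, the failure indicates that one of the two products (dense `M @ x` or sparse `Msp @ x`) contains NaNs that the other does not, not that the finite values are slightly off.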
The relevant part of the test is:

```python
x = rng_b(bshape, dtype)
x = jnp.asarray(x)
self.assertAllClose(
    M @ x, Msp @ x, rtol=sptu.MATMUL_TOL, atol=sptu.MATMUL_TOL
)
# Test mismatched type
x = rng_b(bshape, np.int32)
x = jnp.asarray(x)
with jax.numpy_dtype_promotion('standard'):
  self.assertAllClose(M @ x, Msp @ x, rtol=sptu.MATMUL_TOL)
```

The first test passes, but the second fails. Do you think there is a problem with JAX 0.4.26?
Replies: 1 comment
This is unrelated to `numpy_dtype_promotion`; it looks like there is a difference in NaN behavior between ROCm and the reference implementation. We don't have any ROCm CI tests, because we don't have easy access to CI environments with this architecture. I also personally don't have access to any ROCm chips, so I'm unable to debug this issue. If you'd like to look into it, we'd be happy to accept a pull request either fixing the test or marking it to skip. Thanks!
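When looking into the NaN difference, it may help to see exactly where the NaNs appear. The sketch below is an assumption-laden diagnostic, not part of JAX or NumPy: `nan_mismatch_positions` is a hypothetical helper that lists the indices where only one of two results is NaN, which can show whether the NaNs are confined to particular rows or columns of the product.

```python
import numpy as np


def nan_mismatch_positions(expected, actual):
    """Hypothetical helper: return indices where exactly one array is NaN.

    For complex inputs, np.isnan is True when either the real or the
    imaginary part is NaN.
    """
    expected = np.asarray(expected)
    actual = np.asarray(actual)
    if expected.shape != actual.shape:
        raise ValueError(f"shape mismatch: {expected.shape} vs {actual.shape}")
    mismatch = np.isnan(expected) != np.isnan(actual)
    # np.argwhere yields one row of indices per mismatching element.
    return [tuple(int(i) for i in idx) for idx in np.argwhere(mismatch)]


# Made-up data: a 2x3 complex result whose second row is NaN in only one array.
dense = np.ones((2, 3), dtype=np.complex64)
sparse = dense.copy()
sparse[1, :] = complex(np.nan, np.nan)
print(nan_mismatch_positions(dense, sparse))  # [(1, 0), (1, 1), (1, 2)]
```

Inside the test one could, for example, call `nan_mismatch_positions(M @ x, Msp @ x)` before the assertion; `np.asarray` converts the JAX arrays to NumPy on the host. A row-aligned pattern would point toward how particular rows of the sparse matrix are handled on ROCm, but that interpretation is only a guess until checked on the hardware.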