Skip to content

Commit c09a45a

Browse files
author
jax authors
committed
Merge pull request #20673 from Micky774:sparse_test
PiperOrigin-RevId: 623326571
2 parents c246a97 + a7737ca commit c09a45a

File tree

3 files changed

+5
-37
lines changed

3 files changed

+5
-37
lines changed

jax/experimental/sparse/test_util.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,6 @@
4141
np.complex128: 1e-10,
4242
}
4343

44-
GPU_LOWERING_ENABLED = gpu_sparse and (
45-
gpu_sparse.cuda_is_supported or gpu_sparse.rocm_is_supported
46-
)
47-
4844

4945
def is_sparse(x):
5046
return isinstance(x, sparse.JAXSparse)

tests/sparse_bcoo_bcsr_test.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def _is_required_cuda_version_satisfied(cuda_version):
151151
class BCOOTest(sptu.SparseTestCase):
152152

153153
def gpu_matmul_warning_context(self, msg):
154-
if sptu.GPU_LOWERING_ENABLED and config.jax_bcoo_cusparse_lowering:
154+
if config.jax_bcoo_cusparse_lowering:
155155
return self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, msg)
156156
return contextlib.nullcontext()
157157

@@ -479,9 +479,6 @@ def test_bcoo_dot_general(
479479
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
480480
)
481481
@jax.default_matmul_precision("float32")
482-
@unittest.skipIf(
483-
not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse"
484-
)
485482
@jtu.run_on_devices("gpu")
486483
def test_bcoo_dot_general_cusparse(
487484
self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting
@@ -528,9 +525,6 @@ def f_sparse(lhs_bcoo, lhs, rhs):
528525
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
529526
)
530527
@jax.default_matmul_precision("float32")
531-
@unittest.skipIf(
532-
not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse"
533-
)
534528
@jtu.run_on_devices("gpu")
535529
def test_bcoo_batched_matmat_cusparse(
536530
self,
@@ -581,9 +575,6 @@ def f_sparse(lhs_bcoo, lhs, rhs):
581575
],
582576
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
583577
)
584-
@unittest.skipIf(
585-
not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse"
586-
)
587578
@jtu.run_on_devices("gpu")
588579
def test_bcoo_batched_matmat_default_lowering(
589580
self,
@@ -615,9 +606,6 @@ def test_bcoo_batched_matmat_default_lowering(
615606
matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs)
616607
self.assertArraysEqual(matmat_expected, matmat_default_lowering_fallback)
617608

618-
@unittest.skipIf(
619-
not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse"
620-
)
621609
@jtu.run_on_devices("gpu")
622610
def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self):
623611
"""Tests bcoo dot general with out-of-bound and unsorted indices."""

tests/sparse_test.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from functools import partial
1717
import itertools
1818
import math
19-
import unittest
2019

2120
from absl.testing import absltest
2221
from absl.testing import parameterized
@@ -351,9 +350,6 @@ def test_coo_sorted_indices(self):
351350
mat_resorted = mat_unsorted._sort_indices()
352351
self.assertArraysEqual(mat.todense(), mat_resorted.todense())
353352

354-
@unittest.skipIf(
355-
not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse"
356-
)
357353
@jtu.run_on_devices("gpu")
358354
def test_coo_sorted_indices_gpu_lowerings(self):
359355
dtype = jnp.float32
@@ -544,9 +540,7 @@ def test_coo_matmul_ad(self, shape, dtype, bshape):
544540
dtype=_lowerings.SUPPORTED_DATA_DTYPES,
545541
transpose=[True, False],
546542
)
547-
@unittest.skipIf(
548-
not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse"
549-
)
543+
@jtu.run_on_devices("gpu")
550544
def test_coo_spmv(self, shape, dtype, transpose):
551545
rng_sparse = sptu.rand_sparse(self.rng())
552546
rng_dense = jtu.rand_default(self.rng())
@@ -569,9 +563,7 @@ def test_coo_spmv(self, shape, dtype, transpose):
569563
dtype=_lowerings.SUPPORTED_DATA_DTYPES,
570564
transpose=[True, False],
571565
)
572-
@unittest.skipIf(
573-
not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse"
574-
)
566+
@jtu.run_on_devices("gpu")
575567
def test_coo_spmm(self, shape, dtype, transpose):
576568
rng_sparse = sptu.rand_sparse(self.rng())
577569
rng_dense = jtu.rand_default(self.rng())
@@ -594,9 +586,7 @@ def test_coo_spmm(self, shape, dtype, transpose):
594586
dtype=_lowerings.SUPPORTED_DATA_DTYPES,
595587
transpose=[True, False],
596588
)
597-
@unittest.skipIf(
598-
not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse"
599-
)
589+
@jtu.run_on_devices("gpu")
600590
def test_csr_spmv(self, shape, dtype, transpose):
601591
rng_sparse = sptu.rand_sparse(self.rng())
602592
rng_dense = jtu.rand_default(self.rng())
@@ -617,9 +607,7 @@ def test_csr_spmv(self, shape, dtype, transpose):
617607
dtype=_lowerings.SUPPORTED_DATA_DTYPES,
618608
transpose=[True, False],
619609
)
620-
@unittest.skipIf(
621-
not sptu.GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse"
622-
)
610+
@jtu.run_on_devices("gpu")
623611
def test_csr_spmm(self, shape, dtype, transpose):
624612
rng_sparse = sptu.rand_sparse(self.rng())
625613
rng_dense = jtu.rand_default(self.rng())
@@ -1083,8 +1071,6 @@ class SparseSolverTest(sptu.SparseTestCase):
10831071
)
10841072
@jtu.run_on_devices("cpu", "cuda")
10851073
def test_sparse_qr_linear_solver(self, size, reorder, dtype):
1086-
if jtu.test_device_matches(["cuda"]) and not sptu.GPU_LOWERING_ENABLED:
1087-
raise unittest.SkipTest('test requires cusparse/cusolver')
10881074
rng = sptu.rand_sparse(self.rng())
10891075
a = rng((size, size), dtype)
10901076
nse = (a != 0).sum()
@@ -1110,8 +1096,6 @@ def sparse_solve(data, indices, indptr, b):
11101096
)
11111097
@jtu.run_on_devices("cpu", "cuda")
11121098
def test_sparse_qr_linear_solver_grads(self, size, dtype):
1113-
if jtu.test_device_matches(["cuda"]) and not sptu.GPU_LOWERING_ENABLED:
1114-
raise unittest.SkipTest('test requires cusparse/cusolver')
11151099
rng = sptu.rand_sparse(self.rng())
11161100
a = rng((size, size), dtype)
11171101
nse = (a != 0).sum()

0 commit comments

Comments
 (0)