16
16
from functools import partial
17
17
import itertools
18
18
import math
19
- import unittest
20
19
21
20
from absl .testing import absltest
22
21
from absl .testing import parameterized
@@ -351,9 +350,6 @@ def test_coo_sorted_indices(self):
351
350
mat_resorted = mat_unsorted ._sort_indices ()
352
351
self .assertArraysEqual (mat .todense (), mat_resorted .todense ())
353
352
354
- @unittest .skipIf (
355
- not sptu .GPU_LOWERING_ENABLED , "test requires cusparse/hipsparse"
356
- )
357
353
@jtu .run_on_devices ("gpu" )
358
354
def test_coo_sorted_indices_gpu_lowerings (self ):
359
355
dtype = jnp .float32
@@ -544,9 +540,7 @@ def test_coo_matmul_ad(self, shape, dtype, bshape):
544
540
dtype = _lowerings .SUPPORTED_DATA_DTYPES ,
545
541
transpose = [True , False ],
546
542
)
547
- @unittest .skipIf (
548
- not sptu .GPU_LOWERING_ENABLED , "test requires cusparse/hipsparse"
549
- )
543
+ @jtu .run_on_devices ("gpu" )
550
544
def test_coo_spmv (self , shape , dtype , transpose ):
551
545
rng_sparse = sptu .rand_sparse (self .rng ())
552
546
rng_dense = jtu .rand_default (self .rng ())
@@ -569,9 +563,7 @@ def test_coo_spmv(self, shape, dtype, transpose):
569
563
dtype = _lowerings .SUPPORTED_DATA_DTYPES ,
570
564
transpose = [True , False ],
571
565
)
572
- @unittest .skipIf (
573
- not sptu .GPU_LOWERING_ENABLED , "test requires cusparse/hipsparse"
574
- )
566
+ @jtu .run_on_devices ("gpu" )
575
567
def test_coo_spmm (self , shape , dtype , transpose ):
576
568
rng_sparse = sptu .rand_sparse (self .rng ())
577
569
rng_dense = jtu .rand_default (self .rng ())
@@ -594,9 +586,7 @@ def test_coo_spmm(self, shape, dtype, transpose):
594
586
dtype = _lowerings .SUPPORTED_DATA_DTYPES ,
595
587
transpose = [True , False ],
596
588
)
597
- @unittest .skipIf (
598
- not sptu .GPU_LOWERING_ENABLED , "test requires cusparse/hipsparse"
599
- )
589
+ @jtu .run_on_devices ("gpu" )
600
590
def test_csr_spmv (self , shape , dtype , transpose ):
601
591
rng_sparse = sptu .rand_sparse (self .rng ())
602
592
rng_dense = jtu .rand_default (self .rng ())
@@ -617,9 +607,7 @@ def test_csr_spmv(self, shape, dtype, transpose):
617
607
dtype = _lowerings .SUPPORTED_DATA_DTYPES ,
618
608
transpose = [True , False ],
619
609
)
620
- @unittest .skipIf (
621
- not sptu .GPU_LOWERING_ENABLED , "test requires cusparse/hipsparse"
622
- )
610
+ @jtu .run_on_devices ("gpu" )
623
611
def test_csr_spmm (self , shape , dtype , transpose ):
624
612
rng_sparse = sptu .rand_sparse (self .rng ())
625
613
rng_dense = jtu .rand_default (self .rng ())
@@ -1083,8 +1071,6 @@ class SparseSolverTest(sptu.SparseTestCase):
1083
1071
)
1084
1072
@jtu .run_on_devices ("cpu" , "cuda" )
1085
1073
def test_sparse_qr_linear_solver (self , size , reorder , dtype ):
1086
- if jtu .test_device_matches (["cuda" ]) and not sptu .GPU_LOWERING_ENABLED :
1087
- raise unittest .SkipTest ('test requires cusparse/cusolver' )
1088
1074
rng = sptu .rand_sparse (self .rng ())
1089
1075
a = rng ((size , size ), dtype )
1090
1076
nse = (a != 0 ).sum ()
@@ -1110,8 +1096,6 @@ def sparse_solve(data, indices, indptr, b):
1110
1096
)
1111
1097
@jtu .run_on_devices ("cpu" , "cuda" )
1112
1098
def test_sparse_qr_linear_solver_grads (self , size , dtype ):
1113
- if jtu .test_device_matches (["cuda" ]) and not sptu .GPU_LOWERING_ENABLED :
1114
- raise unittest .SkipTest ('test requires cusparse/cusolver' )
1115
1099
rng = sptu .rand_sparse (self .rng ())
1116
1100
a = rng ((size , size ), dtype )
1117
1101
nse = (a != 0 ).sum ()
0 commit comments