Skip to content

Commit 5923e28

Browse files
committed
Rationalize and try to fix failing ldiv tests
1 parent e561e7a commit 5923e28

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

test/libraries/cusparse/interfaces.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -214,17 +214,15 @@ nB = 2
214214
ldiv!(triangle(opa(dA)), dz)
215215
@test z collect(dz)
216216
end
217-
if SparseMatrixType != CuSparseMatrixBSR
217+
# seems to be a library bug in CUDAs 12.0-12.2, only fp64 types are supported
218+
if SparseMatrixType != CuSparseMatrixBSR || elty (Float64, ComplexF64) || CUSPARSE.version() < v"12.0" || v"12.2" < CUSPARSE.version()
218219
@testset "ldiv! -- (CuVector, CuVector)" begin
219220
z = rand(elty, m)
220221
dz = CuArray(z)
221222
ldiv!(z, triangle(opa(A)), y)
222223
ldiv!(dz, triangle(opa(dA)), dy)
223224
@test z collect(dz)
224225
end
225-
end
226-
# seems to be a library bug in CUDAs 12.0-12.2, only fp64 types are supported
227-
if SparseMatrixType == CuSparseMatrixBSR || elty (Float64, ComplexF64) || CUSPARSE.version() < v"12.0" || v"12.2" < CUSPARSE.version()
228226
@testset "\\ -- CuVector" begin
229227
x = triangle(opa(A)) \ y
230228
dx = triangle(opa(dA)) \ dy
@@ -248,7 +246,7 @@ nB = 2
248246
@test_throws DimensionMismatch(error_str) ldiv!(triangle(opa(dA)), opb(dB_bad))
249247
end
250248
end
251-
if SparseMatrixType != CuSparseMatrixBSR
249+
if SparseMatrixType != CuSparseMatrixBSR || elty (Float64, ComplexF64) || CUSPARSE.version() != v"12.0"
252250
@testset "ldiv! -- (CuMatrix, CuMatrix)" begin
253251
C = rand(elty, m, nB)
254252
dC = CuArray(C)
@@ -259,13 +257,13 @@ nB = 2
259257
@test_throws DimensionMismatch(error_str) ldiv!(triangle(opa(dA)), opb(dB_bad))
260258
end
261259
end
262-
end
263-
@testset "\\ -- CuMatrix" begin
264-
C = triangle(opa(A)) \ opb(B)
265-
dC = triangle(opa(dA)) \ opb(dB)
266-
@test C collect(dC)
267-
if CUSPARSE.version() < v"12.0"
268-
@test_throws DimensionMismatch(error_str) ldiv!(triangle(opa(dA)), opb(dB_bad))
260+
@testset "\\ -- CuMatrix" begin
261+
C = triangle(opa(A)) \ opb(B)
262+
dC = triangle(opa(dA)) \ opb(dB)
263+
@test C collect(dC)
264+
if CUSPARSE.version() < v"12.0"
265+
@test_throws DimensionMismatch(error_str) ldiv!(triangle(opa(dA)), opb(dB_bad))
266+
end
269267
end
270268
end
271269
end

0 commit comments

Comments
 (0)