Skip to content

Commit 9b487af

Browse files
authored
More tests for sparse matrix dimension checks (#2796)
1 parent 523337d commit 9b487af

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

test/libraries/cusparse/bmm.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,25 @@ if CUSPARSE.version() ≥ v"11.7.2"
1515
α = rand(elty)
1616
β = rand(elty)
1717

18+
@testset "Dimension checks" begin
19+
A1 = CuSparseMatrixCSR{elty}(sprand(elty, m, k, p))
20+
A2 = copy(A1)
21+
A2.nzVal = CUDA.rand(elty, size(A2.nzVal)...)
22+
A = cat(A1, A2; dims=3)
23+
24+
B = CUDA.rand(elty, k, n, 2)
25+
C = CUDA.rand(elty, m, n, 3)
26+
27+
@test_throws ArgumentError("C must have same batch-dimension as max(size(A,3)=$(size(A,3)), size(B,3)=$(size(B,3))), got $(size(C,3)).") CUSPARSE.bmm!('N', 'N', α, A, B, β, C, 'O')
28+
29+
C = CUDA.rand(elty, m, 1, 2)
30+
@test_throws ArgumentError("bmm! does not work for n==1 and b>1 due to CUDA error.") CUSPARSE.bmm!('N', 'N', α, A, B, β, C, 'O')
31+
32+
C = CUDA.rand(elty, m, n, 2)
33+
B = CUDA.rand(elty, k+1, n, 2)
34+
@test_throws DimensionMismatch("B has dimensions $(size(B)) but needs ($k,$n)") CUSPARSE.bmm!('N', 'N', α, A, B, β, C, 'O')
35+
end
36+
1837
@testset "C = αAB + βC" begin
1938
A1 = CuSparseMatrixCSR{elty}(sprand(elty, m, k, p))
2039
A2 = copy(A1)

test/libraries/cusparse/generic.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,28 @@ using LinearAlgebra
66

77
if CUSPARSE.version() >= v"11.4.1"
88
@testset "generic mv!" for T in [Float32, Float64]
9-
A = sprand(T, 10, 10, 0.1)
10-
x = rand(Complex{T}, 10)
9+
m = 10
10+
A = sprand(T, m, m, 0.1)
11+
x = rand(Complex{T}, m)
1112
y = similar(x)
1213
dx = adapt(CuArray, x)
1314
dy = adapt(CuArray, y)
1415

1516
dA = adapt(CuArray, A)
16-
mv!('N', T(1.0), dA, dx, T(0.0), dy, 'O')
17+
mv!('N', one(T), dA, dx, zero(T), dy, 'O')
1718
@test Array(dy) A * x
1819

1920
dA = CuSparseMatrixCSR(dA)
20-
mv!('N', T(1.0), dA, dx, T(0.0), dy, 'O')
21+
mv!('N', one(T), dA, dx, zero(T), dy, 'O')
2122
@test Array(dy) A * x
23+
24+
A_bad = sprand(T, m+1, m, 0.1)
25+
dA_bad = adapt(CuArray, A_bad)
26+
@test_throws DimensionMismatch("Y must have length $(m+1), but has length $m") mv!('N', one(T), dA_bad, dx, zero(T), dy, 'O')
27+
28+
A_bad = sprand(T, m, m+1, 0.1)
29+
dA_bad = adapt(CuArray, A_bad)
30+
@test_throws DimensionMismatch("X must have length $(m+1), but has length $m") mv!('N', one(T), dA_bad, dx, zero(T), dy, 'O')
2231
end
2332
end
2433

0 commit comments

Comments
 (0)