Skip to content

Commit 2987086

Browse files
authored
Expose sparse mv/mm algo selection (#1201)
1 parent 28b1ff4 commit 2987086

File tree

2 files changed

+57
-12
lines changed

2 files changed

+57
-12
lines changed

lib/cusparse/generic.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ function gather!(X::CuSparseVector, Y::CuVector, index::SparseChar)
110110
end
111111

112112
function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixBSR{T},CuSparseMatrixCSR{T}},
113-
X::DenseCuVector{T}, beta::Number, Y::DenseCuVector{T}, index::SparseChar) where {T}
113+
X::DenseCuVector{T}, beta::Number, Y::DenseCuVector{T}, index::SparseChar, algo::cusparseSpMVAlg_t=CUSPARSE_MV_ALG_DEFAULT) where {T}
114114
m,n = size(A)
115115

116116
if transa == 'N'
@@ -127,18 +127,18 @@ function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixBSR{T},Cu
127127
function bufferSize()
128128
out = Ref{Csize_t}()
129129
cusparseSpMV_bufferSize(handle(), transa, Ref{compute_type}(alpha), descA, descX, Ref{compute_type}(beta),
130-
descY, compute_type, CUSPARSE_SPMV_ALG_DEFAULT, out)
130+
descY, compute_type, algo, out)
131131
return out[]
132132
end
133133
with_workspace(bufferSize) do buffer
134134
cusparseSpMV(handle(), transa, Ref{compute_type}(alpha), descA, descX, Ref{compute_type}(beta),
135-
descY, compute_type, CUSPARSE_SPMV_ALG_DEFAULT, buffer)
135+
descY, compute_type, algo, buffer)
136136
end
137137
Y
138138
end
139139

140140
function mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T}, X::DenseCuVector{T},
141-
beta::Number, Y::DenseCuVector{T}, index::SparseChar) where {T}
141+
beta::Number, Y::DenseCuVector{T}, index::SparseChar, algo::cusparseSpMVAlg_t=CUSPARSE_MV_ALG_DEFAULT) where {T}
142142
ctransa = 'N'
143143
if transa == 'N'
144144
ctransa = 'T'
@@ -163,19 +163,19 @@ function mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T}, X::Dens
163163
function bufferSize()
164164
out = Ref{Csize_t}()
165165
cusparseSpMV_bufferSize(handle(), ctransa, Ref{compute_type}(alpha), descA, descX, Ref{compute_type}(beta),
166-
descY, compute_type, CUSPARSE_SPMV_ALG_DEFAULT, out)
166+
descY, compute_type, algo, out)
167167
return out[]
168168
end
169169
with_workspace(bufferSize) do buffer
170170
cusparseSpMV(handle(), ctransa, Ref{compute_type}(alpha), descA, descX, Ref{compute_type}(beta),
171-
descY, compute_type, CUSPARSE_SPMV_ALG_DEFAULT, buffer)
171+
descY, compute_type, algo, buffer)
172172
end
173173

174174
return Y
175175
end
176176

177177
function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSR{T},
178-
B::DenseCuMatrix{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar) where {T}
178+
B::DenseCuMatrix{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_MM_ALG_DEFAULT) where {T}
179179
m,k = size(A)
180180
n = size(C)[2]
181181

@@ -197,20 +197,20 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseM
197197
out = Ref{Csize_t}()
198198
cusparseSpMM_bufferSize(
199199
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
200-
descC, T, CUSPARSE_SPMM_ALG_DEFAULT, out)
200+
descC, T, algo, out)
201201
return out[]
202202
end
203203
with_workspace(bufferSize) do buffer
204204
cusparseSpMM(
205205
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
206-
descC, T, CUSPARSE_SPMM_ALG_DEFAULT, buffer)
206+
descC, T, algo, buffer)
207207
end
208208

209209
return C
210210
end
211211

212212
function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T},
213-
B::DenseCuMatrix{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar) where {T}
213+
B::DenseCuMatrix{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_MM_ALG_DEFAULT) where {T}
214214
ctransa = 'N'
215215
if transa == 'N'
216216
ctransa = 'T'
@@ -240,13 +240,13 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseM
240240
out = Ref{Csize_t}()
241241
cusparseSpMM_bufferSize(
242242
handle(), ctransa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
243-
descC, T, CUSPARSE_SPMM_ALG_DEFAULT, out)
243+
descC, T, algo, out)
244244
return out[]
245245
end
246246
with_workspace(bufferSize) do buffer
247247
cusparseSpMM(
248248
handle(), ctransa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
249-
descC, T, CUSPARSE_SPMM_ALG_DEFAULT, buffer)
249+
descC, T, algo, buffer)
250250
end
251251

252252
return C

test/cusparse/generic.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using CUDA.CUSPARSE, SparseArrays
2+
3+
if CUSPARSE.version() >= v"11.4.1" # lower CUDA version doesn't support these algorithms
4+
5+
@testset "mm algo=$algo" for algo in [
6+
CUSPARSE.CUSPARSE_SPMM_ALG_DEFAULT,
7+
CUSPARSE.CUSPARSE_SPMM_CSR_ALG1,
8+
CUSPARSE.CUSPARSE_SPMM_CSR_ALG2,
9+
CUSPARSE.CUSPARSE_SPMM_CSR_ALG3,
10+
]
11+
@testset "mm $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
12+
A = sprand(T, 10, 10, 0.1)
13+
B = rand(T, 10, 2)
14+
C = rand(T, 10, 2)
15+
dA = CuSparseMatrixCSR(A)
16+
dB = CuArray(B)
17+
dC = CuArray(C)
18+
19+
alpha = 1.2
20+
beta = 1.3
21+
mm!('N', 'N', alpha, dA, dB, beta, dC, 'O', algo)
22+
@test alpha * A * B + beta * C collect(dC)
23+
end
24+
end
25+
26+
@testset "mv algo=$algo" for algo in [
27+
CUSPARSE.CUSPARSE_SPMV_ALG_DEFAULT,
28+
CUSPARSE.CUSPARSE_SPMV_CSR_ALG1,
29+
CUSPARSE.CUSPARSE_SPMV_CSR_ALG2,
30+
]
31+
@testset "mv $T" for T in [Float32, Float64, ComplexF32, ComplexF64]
32+
A = sprand(T, 10, 10, 0.1)
33+
B = rand(T, 10)
34+
C = rand(T, 10)
35+
dA = CuSparseMatrixCSR(A)
36+
dB = CuArray(B)
37+
dC = CuArray(C)
38+
39+
alpha = 1.2
40+
beta = 1.3
41+
mv!('N', alpha, dA, dB, beta, dC, 'O', algo)
42+
@test alpha * A * B + beta * C collect(dC)
43+
end
44+
end
45+
end # version >= 11.4.1

0 commit comments

Comments
 (0)