Skip to content

Commit 1428100

Browse files
authored
[CUSPARSE] Add sddmm! and gemvi! routines (#1615)
1 parent a753cba commit 1428100

File tree

6 files changed

+235
-22
lines changed

6 files changed

+235
-22
lines changed

lib/cusparse/generic.jl

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# generic APIs
22

3-
export vv!, sv!, sm!, gemm, gemm!
3+
export gather!, scatter!, axpby!, rot!
4+
export vv!, sv!, sm!, gemm, gemm!, sddmm!
45

56
## API functions
67

@@ -216,6 +217,11 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::Union{CuS
216217
return out[]
217218
end
218219
with_workspace(bufferSize) do buffer
220+
# Uncomment if we find a way to reuse the buffer (issue #1362)
221+
# cusparseSpMM_preprocess(
222+
# handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
223+
# descC, T, algo, buffer)
224+
# end
219225
cusparseSpMM(
220226
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
221227
descC, T, algo, buffer)
@@ -451,3 +457,40 @@ function sm!(transa::SparseChar, transb::SparseChar, uplo::SparseChar, diag::Spa
451457
end
452458
return C
453459
end
460+
461+
"""
    sddmm!(transa, transb, alpha, A, B, beta, C, index, algo=CUSPARSE_SDDMM_ALG_DEFAULT)

In-place wrapper around `cusparseSDDMM`. The dense matrices `A` and `B`, the scalars `alpha`
and `beta` (converted to `T`) and the CSR matrix `C` are handed to cuSPARSE; `C` is returned.

`transa` / `transb` select the operation applied to `A` / `B`. For real `T` a requested 'C'
is replaced by 'T'. Operand shapes are validated with `chkmmdims` before any descriptor is
created. `index` is forwarded to the sparse descriptor of `C`.

NOTE(review): per the cuSPARSE documentation SDDMM computes
`C = alpha * (op(A) * op(B)) ∘ spy(C) + beta * C` — confirm against the vendor docs.
"""
function sddmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMatrix{T}, B::DenseCuMatrix{T},
                beta::Number, C::CuSparseMatrixCSR{T}, index::SparseChar, algo::cusparseSDDMMAlg_t=CUSPARSE_SDDMM_ALG_DEFAULT) where {T}

    # Support transa = 'C' and transb = 'C' for real matrices
    opa = (T <: Real && transa == 'C') ? 'T' : transa
    opb = (T <: Real && transb == 'C') ? 'T' : transb

    m, k = size(A)
    n = size(C, 2)

    # Extents of op(A): (m, k) when A is used as stored, swapped otherwise.
    outer, inner = opa == 'N' ? (m, k) : (k, m)
    # Expected size of B as stored depends on whether it is used as-is or transposed.
    if opb == 'N'
        chkmmdims(B, C, inner, n, outer, n)
    else
        chkmmdims(B, C, n, inner, outer, n)
    end

    descA = CuDenseMatrixDescriptor(A)
    descB = CuDenseMatrixDescriptor(B)
    descC = CuSparseMatrixDescriptor(C, index)

    # Ask cuSPARSE how much scratch space this particular call needs.
    function workspace_bytes()
        nbytes = Ref{Csize_t}()
        cusparseSDDMM_bufferSize(handle(), opa, opb, Ref{T}(alpha), descA, descB, Ref{T}(beta), descC, T, algo, nbytes)
        return nbytes[]
    end

    with_workspace(workspace_bytes) do workspace
        # Uncomment if we find a way to reuse the buffer (issue #1362)
        # cusparseSDDMM_preprocess(handle(), opa, opb, Ref{T}(alpha), descA, descB, Ref{T}(beta), descC, T, algo, workspace)
        cusparseSDDMM(handle(), opa, opb, Ref{T}(alpha), descA, descB, Ref{T}(beta), descC, T, algo, workspace)
    end

    return C
end

lib/cusparse/interfaces.jl

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
4141
mv_wrapper($transa(T), alpha, $(untaga(unwrapa(:A))), B, beta, C)
4242
end
4343

44-
function LinearAlgebra.mul!(C::CuVector{Complex{T}}, A::$TypeA, B::DenseCuVector{Complex{T}},
45-
alpha::Number, beta::Number) where {T <: Union{Float16, BlasFloat}}
46-
mv_wrapper($transa(T), alpha, $(untaga(unwrapa(:A))), B, beta, C)
44+
function LinearAlgebra.mul!(C::CuVector{T}, A::$TypeA, B::CuSparseVector{T},
45+
alpha::Number, beta::Number) where {T <: Union{Float16, ComplexF16, BlasFloat}}
46+
mv_wrapper($transa(T), alpha, $(untaga(unwrapa(:A))), CuVector{T}(B), beta, C)
4747
end
4848
end
4949

@@ -59,6 +59,22 @@ for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
5959
end
6060
end
6161

62+
# Dense matrix (optionally tagged and/or lazily transposed) times sparse vector.
# One `mul!` and one `*` method is generated per combination of `tag_wrappers` and
# `op_wrappers`.
for (tag, untag) in tag_wrappers, (wrap, opchar, unwrap) in op_wrappers
    AType = wrap(tag(:(DenseCuMatrix{T})))
    @eval begin
        # 5-argument `mul!`: strips the wrappers from `A` and forwards to `gemvi!`,
        # passing 'O' as the index argument.
        function LinearAlgebra.mul!(C::CuVector{T}, A::$AType, B::CuSparseVector{T},
                                    alpha::Number, beta::Number) where {T <: BlasFloat}
            return gemvi!($opchar(T), alpha, $(untag(unwrap(:A))), B, beta, C, 'O')
        end

        # Allocating product: checks that `x` matches the column count of `A`, allocates
        # an uninitialised result of length `size(A, 1)` and fills it through `mul!`.
        function SparseArrays.:(*)(A::$AType, x::CuSparseVector{T}) where {T <: BlasFloat}
            nrows, ncols = size(A)
            if length(x) != ncols
                throw(DimensionMismatch())
            end
            result = CuVector{T}(undef, nrows)
            return mul!(result, A, x)
        end
    end
end
77+
6278
# `A + B` for two CSR matrices: `geam` with both coefficients equal to one.
function Base.:(+)(A::CuSparseMatrixCSR, B::CuSparseMatrixCSR)
    return geam(one(eltype(A)), A, one(eltype(A)), B, 'O')
end

# `A - B` for two CSR matrices: `geam` with the second coefficient negated.
function Base.:(-)(A::CuSparseMatrixCSR, B::CuSparseMatrixCSR)
    return geam(one(eltype(A)), A, -one(eltype(A)), B, 'O')
end
6480

@@ -143,6 +159,14 @@ for (t, uploc, isunitc) in ((:LowerTriangular, 'L', 'N'),
143159
LinearAlgebra.ldiv!(A::$t{T,<:AbstractCuSparseMatrix},
144160
B::DenseCuMatrix{T}) where {T<:BlasFloat} =
145161
sm2!('N', 'N', $uploc, $isunitc, one(T), parent(A), B, 'O')
162+
163+
# Left division by a triangular-wrapped sparse matrix with a lazily transposed dense
# right-hand side: both wrappers are stripped and 'T' is passed to `sm2!` for the RHS.
function LinearAlgebra.ldiv!(A::$t{T,<:AbstractCuSparseMatrix},
                             B::Transpose{T,<:DenseCuMatrix}) where {T<:BlasFloat}
    rhs = parent(B)
    return sm2!('N', 'T', $uploc, $isunitc, one(T), parent(A), rhs, 'O')
end

# Same as above for an adjoint right-hand side ('C' is passed for the RHS).
function LinearAlgebra.ldiv!(A::$t{T,<:AbstractCuSparseMatrix},
                             B::Adjoint{T,<:DenseCuMatrix}) where {T<:BlasFloat}
    rhs = parent(B)
    return sm2!('N', 'C', $uploc, $isunitc, one(T), parent(A), rhs, 'O')
end
146170
end
147171
end
148172

@@ -153,30 +177,38 @@ for (t, uploc, isunitc) in ((:LowerTriangular, 'U', 'N'),
153177
(:UnitUpperTriangular, 'L', 'U'))
154178
@eval begin
155179
# Left division with vectors
156-
LinearAlgebra.ldiv!(A::$t{<:Any,<:Transpose{T,<:AbstractCuSparseMatrix}},
180+
# Triangular wrapper around a transposed sparse matrix, dense vector RHS:
# unwraps twice to reach the stored matrix and calls `sv2!` with 'T'.
function LinearAlgebra.ldiv!(A::$t{T,<:Transpose{T,<:AbstractCuSparseMatrix}},
                             B::DenseCuVector{T}) where {T<:BlasFloat}
    stored = parent(parent(A))
    return sv2!('T', $uploc, $isunitc, one(T), stored, B, 'O')
end
159183

160-
LinearAlgebra.ldiv!(A::$t{<:Any,<:Adjoint{T,<:AbstractCuSparseMatrix}},
161-
B::DenseCuVector{T}) where {T<:BlasReal} =
162-
sv2!('T', $uploc, $isunitc, one(T), parent(parent(A)), B, 'O')
163-
164-
LinearAlgebra.ldiv!(A::$t{<:Any,<:Adjoint{T,<:AbstractCuSparseMatrix}},
165-
B::DenseCuVector{T}) where {T<:BlasComplex} =
184+
# Triangular wrapper around an adjoint sparse matrix, dense vector RHS:
# unwraps twice to reach the stored matrix and calls `sv2!` with 'C'.
function LinearAlgebra.ldiv!(A::$t{T,<:Adjoint{T,<:AbstractCuSparseMatrix}},
                             B::DenseCuVector{T}) where {T<:BlasFloat}
    stored = parent(parent(A))
    return sv2!('C', $uploc, $isunitc, one(T), stored, B, 'O')
end
167187

168188
# Left division with matrices
169-
LinearAlgebra.ldiv!(A::$t{<:Any,<:Transpose{T,<:AbstractCuSparseMatrix}},
189+
# Triangular wrapper around a transposed sparse matrix, plain dense matrix RHS:
# calls `sm2!` with 'T' for the sparse operand and 'N' for the RHS.
function LinearAlgebra.ldiv!(A::$t{T,<:Transpose{T,<:AbstractCuSparseMatrix}},
                             B::DenseCuMatrix{T}) where {T<:BlasFloat}
    stored = parent(parent(A))
    return sm2!('T', 'N', $uploc, $isunitc, one(T), stored, B, 'O')
end
172192

173-
LinearAlgebra.ldiv!(A::$t{<:Any,<:Adjoint{T,<:AbstractCuSparseMatrix}},
174-
B::DenseCuMatrix{T}) where {T<:BlasReal} =
175-
sm2!('T', 'N', $uploc, $isunitc, one(T), parent(parent(A)), B, 'O')
193+
# Triangular wrapper around a transposed sparse matrix, transposed dense RHS:
# both operands are unwrapped and `sm2!` is called with ('T', 'T').
function LinearAlgebra.ldiv!(A::$t{T,<:Transpose{T,<:AbstractCuSparseMatrix}},
                             B::Transpose{T,<:DenseCuMatrix}) where {T<:BlasFloat}
    stored = parent(parent(A))
    rhs = parent(B)
    return sm2!('T', 'T', $uploc, $isunitc, one(T), stored, rhs, 'O')
end

# Triangular wrapper around a transposed sparse matrix, adjoint dense RHS:
# both operands are unwrapped and `sm2!` is called with ('T', 'C').
function LinearAlgebra.ldiv!(A::$t{T,<:Transpose{T,<:AbstractCuSparseMatrix}},
                             B::Adjoint{T,<:DenseCuMatrix}) where {T<:BlasFloat}
    stored = parent(parent(A))
    rhs = parent(B)
    return sm2!('T', 'C', $uploc, $isunitc, one(T), stored, rhs, 'O')
end
176200

177-
LinearAlgebra.ldiv!(A::$t{<:Any,<:Adjoint{T,<:AbstractCuSparseMatrix}},
178-
B::DenseCuMatrix{T}) where {T<:BlasComplex} =
201+
# Triangular wrapper around an adjoint sparse matrix, plain dense matrix RHS:
# calls `sm2!` with 'C' for the sparse operand and 'N' for the RHS.
function LinearAlgebra.ldiv!(A::$t{T,<:Adjoint{T,<:AbstractCuSparseMatrix}},
                             B::DenseCuMatrix{T}) where {T<:BlasFloat}
    stored = parent(parent(A))
    return sm2!('C', 'N', $uploc, $isunitc, one(T), stored, B, 'O')
end
204+
205+
# Triangular wrapper around an adjoint sparse matrix, transposed dense RHS:
# both operands are unwrapped and `sm2!` is called with ('C', 'T').
function LinearAlgebra.ldiv!(A::$t{T,<:Adjoint{T,<:AbstractCuSparseMatrix}},
                             B::Transpose{T,<:DenseCuMatrix}) where {T<:BlasFloat}
    stored = parent(parent(A))
    rhs = parent(B)
    return sm2!('C', 'T', $uploc, $isunitc, one(T), stored, rhs, 'O')
end

# Triangular wrapper around an adjoint sparse matrix, adjoint dense RHS:
# both operands are unwrapped and `sm2!` is called with ('C', 'C').
function LinearAlgebra.ldiv!(A::$t{T,<:Adjoint{T,<:AbstractCuSparseMatrix}},
                             B::Adjoint{T,<:DenseCuMatrix}) where {T<:BlasFloat}
    stored = parent(parent(A))
    rhs = parent(B)
    return sm2!('C', 'C', $uploc, $isunitc, one(T), stored, rhs, 'O')
end
180212
end
181213
end
182214

lib/cusparse/level2.jl

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# sparse linear algebra functions that perform operations between sparse matrices and dense
22
# vectors
33

4-
export sv2!, sv2, mv!
4+
export sv2!, sv2, mv!, gemvi!
55

66
"""
77
mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrix, X::CuVector, beta::Number, Y::CuVector, index::SparseChar)
@@ -24,6 +24,10 @@ for (fname,elty) in ((:cusparseSbsrmv, :Float32),
2424
beta::Number,
2525
Y::CuVector{$elty},
2626
index::SparseChar)
27+
28+
# Support transa = 'C' for real matrices
29+
transa = $elty <: Real && transa == 'C' ? 'T' : transa
30+
2731
desc = CuMatrixDescriptor('G', 'L', 'N', index)
2832
m,n = size(A)
2933
mb = div(m,A.blockDim)
@@ -65,6 +69,10 @@ for (bname,aname,sname,elty) in ((:cusparseSbsrsv2_bufferSize, :cusparseSbsrsv2_
6569
A::CuSparseMatrixBSR{$elty},
6670
X::CuVector{$elty},
6771
index::SparseChar)
72+
73+
# Support transa = 'C' for real matrices
74+
transa = $elty <: Real && transa == 'C' ? 'T' : transa
75+
6876
desc = CuMatrixDescriptor('G', uplo, diag, index)
6977
m,n = size(A)
7078
if m != n
@@ -118,6 +126,10 @@ for (bname,aname,sname,elty) in ((:cusparseScsrsv2_bufferSize, :cusparseScsrsv2_
118126
A::CuSparseMatrixCSR{$elty},
119127
X::CuVector{$elty},
120128
index::SparseChar)
129+
130+
# Support transa = 'C' for real matrices
131+
transa = $elty <: Real && transa == 'C' ? 'T' : transa
132+
121133
desc = CuMatrixDescriptor('G', uplo, diag, index)
122134
m,n = size(A)
123135
if m != n
@@ -170,11 +182,15 @@ for (bname,aname,sname,elty) in ((:cusparseScsrsv2_bufferSize, :cusparseScsrsv2_
170182
A::CuSparseMatrixCSC{$elty},
171183
X::CuVector{$elty},
172184
index::SparseChar)
185+
186+
# Support transa = 'C' for real matrices
187+
transa = $elty <: Real && transa == 'C' ? 'T' : transa
188+
173189
ctransa = 'N'
174190
cuplo = 'U'
175191
if transa == 'N'
176192
ctransa = 'T'
177-
elseif transa == 'C' && eltype(A) <: Complex
193+
elseif transa == 'C' && $elty <: Complex
178194
throw(ArgumentError("Backward and forward sweeps with the adjoint of a complex CSC matrix is not supported. Use a CSR matrix instead."))
179195
end
180196
if uplo == 'U'
@@ -247,3 +263,37 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
247263
end
248264
end
249265
end
266+
267+
# gemvi!
268+
"""
    gemvi!(transa::SparseChar, alpha::Number, A::CuMatrix, X::CuSparseVector, beta::Number, Y::CuVector, index::SparseChar)

Performs `Y = alpha * op(A) * X + beta * Y` in place and returns `Y`, where `A` is a dense
matrix, `X` is a sparse vector and `Y` is a dense vector. `op` is selected by `transa`:
'N' uses `A` as stored, any other value uses a transposed operand ('C' is mapped to 'T'
for real element types).

`index` is forwarded unchanged to the cuSPARSE routine (the `mul!` wrapper passes 'O').

Throws a `DimensionMismatch` when `length(X)` or `length(Y)` is incompatible with
`size(A)` for the requested operation.
"""
function gemvi! end

for (bname,fname,elty) in ((:cusparseSgemvi_bufferSize, :cusparseSgemvi, :Float32),
                           (:cusparseDgemvi_bufferSize, :cusparseDgemvi, :Float64),
                           (:cusparseCgemvi_bufferSize, :cusparseCgemvi, :ComplexF32),
                           (:cusparseZgemvi_bufferSize, :cusparseZgemvi, :ComplexF64))
    @eval begin
        function gemvi!(transa::SparseChar,
                        alpha::Number,
                        A::CuMatrix{$elty},
                        X::CuSparseVector{$elty},
                        beta::Number,
                        Y::CuVector{$elty},
                        index::SparseChar)

            # Support transa = 'C' for real matrices
            transa = $elty <: Real && transa == 'C' ? 'T' : transa

            m,n = size(A)

            # Validate operand lengths before raw device pointers reach cuSPARSE:
            # op(A) is m×n for 'N' and n×m otherwise.
            xlen, ylen = transa == 'N' ? (n, m) : (m, n)
            length(X) == xlen ||
                throw(DimensionMismatch("X has length $(length(X)) but needs $xlen"))
            length(Y) == ylen ||
                throw(DimensionMismatch("Y has length $(length(Y)) but needs $ylen"))

            # Leading dimension: stride between consecutive columns, never below 1.
            lda = max(1,stride(A,2))
            nnzX = nnz(X)

            # Ask cuSPARSE how large the scratch buffer must be for this call.
            function bufferSize()
                out = Ref{Cint}()
                $bname(handle(), transa, m, n, nnzX, out)
                return out[]
            end
            with_workspace(bufferSize) do buffer
                $fname(handle(), transa, m, n, alpha, A, lda, nnzX, nonzeros(X), nonzeroinds(X), beta, Y, index, buffer)
            end
            return Y
        end
    end
end

lib/cusparse/level3.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ for (fname,elty) in ((:cusparseSbsrmm, :Float32),
2626
beta::Number,
2727
C::StridedCuMatrix{$elty},
2828
index::SparseChar)
29+
30+
# Support transa = 'C' and transb = 'C' for real matrices
31+
transa = $elty <: Real && transa == 'C' ? 'T' : transa
32+
transb = $elty <: Real && transb == 'C' ? 'T' : transb
33+
2934
desc = CuMatrixDescriptor('G', 'L', 'N', index)
3035
m,k = size(A)
3136
mb = cld(m, A.blockDim)
@@ -64,6 +69,11 @@ for (fname,elty) in ((:cusparseScsrmm2, :Float32),
6469
beta::Number,
6570
C::StridedCuMatrix{$elty},
6671
index::SparseChar)
72+
73+
# Support transa = 'C' and transb = 'C' for real matrices
74+
transa = $elty <: Real && transa == 'C' ? 'T' : transa
75+
transb = $elty <: Real && transb == 'C' ? 'T' : transb
76+
6777
if transb == 'C'
6878
throw(ArgumentError("B^H is not supported"))
6979
elseif transb == 'T' && transa != 'N'
@@ -96,6 +106,11 @@ for (fname,elty) in ((:cusparseScsrmm2, :Float32),
96106
beta::Number,
97107
C::StridedCuMatrix{$elty},
98108
index::SparseChar)
109+
110+
# Support transa = 'C' and transb = 'C' for real matrices
111+
transa = $elty <: Real && transa == 'C' ? 'T' : transa
112+
transb = $elty <: Real && transb == 'C' ? 'T' : transb
113+
99114
if transb == 'C'
100115
throw(ArgumentError("B^H is not supported"))
101116
elseif transb == 'T' && transa != 'N'
@@ -151,6 +166,11 @@ for (bname,aname,sname,elty) in ((:cusparseSbsrsm2_bufferSize, :cusparseSbsrsm2_
151166
A::CuSparseMatrixBSR{$elty},
152167
X::StridedCuMatrix{$elty},
153168
index::SparseChar)
169+
170+
# Support transa = 'C' and transxy = 'C' for real matrices
171+
transa = $elty <: Real && transa == 'C' ? 'T' : transa
172+
transxy = $elty <: Real && transxy == 'C' ? 'T' : transxy
173+
154174
desc = CuMatrixDescriptor('G', uplo, diag, index)
155175
m,n = size(A)
156176
if m != n
@@ -211,6 +231,11 @@ for (bname,aname,sname,elty) in ((:cusparseScsrsm2_bufferSizeExt, :cusparseScsrs
211231
A::CuSparseMatrixCSR{$elty},
212232
X::StridedCuMatrix{$elty},
213233
index::SparseChar)
234+
235+
# Support transa = 'C' and transxy = 'C' for real matrices
236+
transa = $elty <: Real && transa == 'C' ? 'T' : transa
237+
transxy = $elty <: Real && transxy == 'C' ? 'T' : transxy
238+
214239
desc = CuMatrixDescriptor('G', uplo, diag, index)
215240
m,n = size(A)
216241
if m != n
@@ -271,11 +296,16 @@ for (bname,aname,sname,elty) in ((:cusparseScsrsm2_bufferSizeExt, :cusparseScsrs
271296
A::CuSparseMatrixCSC{$elty},
272297
X::StridedCuMatrix{$elty},
273298
index::SparseChar)
299+
300+
# Support transa = 'C' and transxy = 'C' for real matrices
301+
transa = $elty <: Real && transa == 'C' ? 'T' : transa
302+
transxy = $elty <: Real && transxy == 'C' ? 'T' : transxy
303+
274304
ctransa = 'N'
275305
cuplo = 'U'
276306
if transa == 'N'
277307
ctransa = 'T'
278-
elseif transa == 'C' && eltype(A) <: Complex
308+
elseif transa == 'C' && $elty <: Complex
279309
throw(ArgumentError("Backward and forward sweeps with the adjoint of a complex CSC matrix is not supported. Use a CSR matrix instead."))
280310
end
281311
if uplo == 'U'

test/cusparse.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,3 +1052,22 @@ end
10521052
end
10531053
end
10541054
end
1055+
1056+
# Checks `gemvi!` against the CPU reference `alpha * op(A) * x + beta * y` for every
# supported element type and operation.
# NOTE(review): `m`, `n` and `k` are not defined in this testset — presumably globals of
# the surrounding test file; confirm.
@testset "gemvi!" begin
    @testset for elty in [Float32,Float64,ComplexF32,ComplexF64]
        for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
            # The conjugate-transpose case is skipped for complex element types.
            # NOTE(review): presumably unsupported by cuSPARSE gemvi — confirm.
            elty <: Complex && transa == 'C' && continue
            # A is stored so that op(A) is always m×n.
            A = transa == 'N' ? rand(elty,m,n) : rand(elty,n,m)
            x = sparsevec(rand(1:n,k), rand(elty,k), n)
            y = rand(elty,m)
            dA = CuArray(A)
            dx = CuSparseVector(x)
            dy = CuArray(y)
            alpha = rand(elty)
            beta = rand(elty)
            gemvi!(transa,alpha,dA,dx,beta,dy,'O')
            z = alpha * opa(A) * x + beta * y
            # Floating-point results are compared approximately.
            @test z ≈ collect(dy)
        end
    end
end

0 commit comments

Comments
 (0)