Skip to content

Commit 8170267

Browse files
CUSPARSE: Improve support for UniformScaling and Diagonal (#1941)
1 parent bf813a5 commit 8170267

File tree

3 files changed

+93
-47
lines changed

3 files changed

+93
-47
lines changed

lib/cusparse/conversions.jl

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
export sort_csc, sort_csr, sort_coo
22

3+
adjtrans_wrappers = ((identity, identity),
4+
(M -> :(Transpose{T, <:$M}), M -> :(_sptranspose(parent($M)))),
5+
(M -> :(Adjoint{T, <:$M}), M -> :(_spadjoint(parent($M)))))
6+
37
# conversion routines between different sparse and dense storage formats
48

59
"""
@@ -45,6 +49,13 @@ function SparseArrays.sparse(I::CuVector{Cint}, J::CuVector{Cint}, V::CuVector{T
4549
end
4650
end
4751

52+
for (wrapa, unwrapa) in adjtrans_wrappers
53+
for SparseMatrixType in (:(CuSparseMatrixCSC{T}), :(CuSparseMatrixCSR{T}), :(CuSparseMatrixCOO{T}))
54+
TypeA = wrapa(SparseMatrixType)
55+
@eval SparseArrays.sparse(A::$TypeA) where {T} = $(unwrapa(:A))
56+
end
57+
end
58+
4859
function sort_csc(A::CuSparseMatrixCSC{Tv,Ti}, index::SparseChar='O') where {Tv,Ti}
4960

5061
m,n = size(A)
@@ -214,16 +225,24 @@ function CuSparseMatrixCSR{T}(S::Adjoint{T, <:CuSparseMatrixCSC{T}}) where {T <:
214225
return CuSparseMatrixCSR{T}(csc.colPtr, csc.rowVal, conj.(csc.nzVal), size(csc))
215226
end
216227

217-
for SparseMatrixType in [:CuSparseMatrixCSC, :CuSparseMatrixCSR]
228+
for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR, :CuSparseMatrixCOO)
218229
@eval begin
219-
$SparseMatrixType(S::Diagonal) = $SparseMatrixType(cu(S))
220-
$SparseMatrixType(S::Diagonal{T, <:CuArray}) where T = $SparseMatrixType{T}(S)
221-
$SparseMatrixType{Tv}(S::Diagonal{T, <:CuArray}) where {Tv, T} = $SparseMatrixType{Tv, Cint}(S)
222-
function $SparseMatrixType{Tv, Ti}(S::Diagonal{T, <:CuArray}) where {Tv, Ti, T}
230+
$SparseMatrixType(S::Diagonal{Tv, <:AbstractVector}) where {Tv} = $SparseMatrixType(cu(S))
231+
$SparseMatrixType(S::Diagonal{Tv, <:CuArray}) where Tv = $SparseMatrixType{Tv}(S)
232+
$SparseMatrixType{Tv}(S::Diagonal) where {Tv} = $SparseMatrixType{Tv, Cint}(S)
233+
end
234+
235+
if SparseMatrixType == :CuSparseMatrixCOO
236+
@eval function $SparseMatrixType{Tv, Ti}(S::Diagonal) where {Tv, Ti}
223237
m = size(S, 1)
224-
return $SparseMatrixType{Tv, Ti}(CuVector(1:(m+1)), CuVector(1:m), Tv.(S.diag), (m, m))
238+
return $SparseMatrixType{Tv, Ti}(CuVector(1:m), CuVector(1:m), convert(CuVector{Tv}, S.diag), (m, m))
225239
end
226-
end
240+
else
241+
@eval function $SparseMatrixType{Tv, Ti}(S::Diagonal) where {Tv, Ti}
242+
m = size(S, 1)
243+
return $SparseMatrixType{Tv, Ti}(CuVector(1:(m+1)), CuVector(1:m), convert(CuVector{Tv}, S.diag), (m, m))
244+
end
245+
end
227246
end
228247

229248
# by flipping rows and columns, we can use that to get CSC to CSR too

lib/cusparse/interfaces.jl

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -409,20 +409,52 @@ function _sparse_identity(::Type{<:CuSparseMatrixCSC{<:Any,Ti}},
409409
CuSparseMatrixCSC{Tv,Ti}(colPtr, rowVal, nzVal, dims)
410410
end
411411

412-
Base.:(+)(A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}, J::UniformScaling) =
413-
A .+ _sparse_identity(typeof(A), J, size(A))
412+
function _sparse_identity(::Type{<:CuSparseMatrixCOO{Tv,Ti}},
413+
I::UniformScaling, dims::Dims) where {Tv,Ti}
414+
len = min(dims[1], dims[2])
415+
rowInd = CuVector{Ti}(1:len)
416+
colInd = CuVector{Ti}(1:len)
417+
nzVal = CUDA.fill(I.λ, len)
418+
CuSparseMatrixCOO{Tv,Ti}(rowInd, colInd, nzVal, dims)
419+
end
414420

415-
Base.:(-)(J::UniformScaling, A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}) =
416-
_sparse_identity(typeof(A), J, size(A)) .- A
421+
for (wrapa, unwrapa) in adjtrans_wrappers
422+
for SparseMatrixType in (:(CuSparseMatrixCSC{T}), :(CuSparseMatrixCSR{T}), :(CuSparseMatrixCOO{T}))
423+
TypeA = wrapa(SparseMatrixType)
424+
@eval begin
425+
Base.:(+)(A::$TypeA, J::UniformScaling) where {T} = $(unwrapa(:A)) + _sparse_identity(typeof(A), J, size(A))
426+
Base.:(+)(J::UniformScaling, A::$TypeA) where {T} = _sparse_identity(typeof(A), J, size(A)) + $(unwrapa(:A))
417427

418-
# TODO: let Broadcast handle this automatically (a la SparseArrays.PromoteToSparse)
419-
for SparseMatrixType in [:CuSparseMatrixCSC, :CuSparseMatrixCSR], op in [:(+), :(-)]
420-
@eval begin
421-
function Base.$op(lhs::Diagonal{T,<:CuArray}, rhs::$SparseMatrixType{T}) where {T}
422-
return $op($SparseMatrixType(lhs), rhs)
428+
Base.:(-)(A::$TypeA, J::UniformScaling) where {T} = $(unwrapa(:A)) - _sparse_identity(typeof(A), J, size(A))
429+
Base.:(-)(J::UniformScaling, A::$TypeA) where {T} = _sparse_identity(typeof(A), J, size(A)) - $(unwrapa(:A))
423430
end
424-
function Base.$op(lhs::$SparseMatrixType{T}, rhs::Diagonal{T,<:CuArray}) where {T}
425-
return $op(lhs, $SparseMatrixType(rhs))
431+
432+
# Broadcasting is not yet supported for COO matrices
433+
if SparseMatrixType != :(CuSparseMatrixCOO{T})
434+
@eval begin
435+
Base.:(*)(A::$TypeA, J::UniformScaling) where {T} = $(unwrapa(:A)) * J.λ
436+
Base.:(*)(J::UniformScaling, A::$TypeA) where {T} = J.λ * $(unwrapa(:A))
437+
end
438+
else
439+
@eval begin
440+
Base.:(*)(A::$TypeA, J::UniformScaling) where {T} = $(unwrapa(:A)) * _sparse_identity(typeof(A), J, size(A))
441+
Base.:(*)(J::UniformScaling, A::$TypeA) where {T} = _sparse_identity(typeof(A), J, size(A)) * $(unwrapa(:A))
442+
end
443+
end
444+
end
445+
end
446+
447+
# TODO: let Broadcast handle this automatically (a la SparseArrays.PromoteToSparse)
448+
for (wrapa, unwrapa) in adjtrans_wrappers, op in (:(+), :(-), :(*))
449+
for SparseMatrixType in (:(CuSparseMatrixCSC{T}), :(CuSparseMatrixCSR{T}), :(CuSparseMatrixCOO{T}))
450+
TypeA = wrapa(SparseMatrixType)
451+
@eval begin
452+
function Base.$op(lhs::Diagonal, rhs::$TypeA) where {T}
453+
return $op($SparseMatrixType(lhs), $(unwrapa(:rhs)))
454+
end
455+
function Base.$op(lhs::$TypeA, rhs::Diagonal) where {T}
456+
return $op($(unwrapa(:lhs)), $SparseMatrixType(rhs))
457+
end
426458
end
427459
end
428460
end

test/libraries/cusparse/interfaces.jl

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -319,37 +319,32 @@ using LinearAlgebra, SparseArrays
319319
@test SparseMatrixCSC(CuSparseMatrixCSR(T)) f(S)
320320
end
321321

322-
VERSION >= v"1.7" && @testset "UniformScaling with $typ($dims)" for
323-
typ in [CuSparseMatrixCSR, CuSparseMatrixCSC],
324-
dims in [(10, 10), (5, 10), (10, 5)]
325-
S = sprand(Float32, dims..., 0.1)
326-
dA = typ(S)
327-
328-
@test Array(dA + I) == S + I
329-
@test Array(I + dA) == I + S
330-
331-
@test Array(dA - I) == S - I
332-
@test Array(I - dA) == I - S
322+
VERSION >= v"1.7" && @testset "UniformScaling basic operations" begin
323+
for elty in (Float32, Float64, ComplexF32, ComplexF64)
324+
A = sprand(elty, 100, 100, 0.1)
325+
U1 = 2*I
326+
for SparseMatrixType in (CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO)
327+
B = SparseMatrixType(A)
328+
for op in (+, -, *)
329+
@test Array(op(B, U1)) op(A, U1) && Array(op(U1, B)) op(U1, A)
330+
end
331+
end
332+
end
333333
end
334334

335-
@testset "Diagonal with $typ(10, 10)" for
336-
typ in [CuSparseMatrixCSR, CuSparseMatrixCSC]
337-
338-
S = sprand(Float32, 10, 10, 0.8)
339-
D = Diagonal(rand(Float32, 10))
340-
dA = typ(S)
341-
dD = adapt(CuArray, D)
342-
343-
@test Array(dA + dD) == S + D
344-
@test Array(dD + dA) == D + S
345-
346-
@test Array(dA - dD) == S - D
347-
@test Array(dD - dA) == D - S
348-
349-
@test dA + dD isa typ
350-
@test dD + dA isa typ
351-
@test dA - dD isa typ
352-
@test dD - dA isa typ
335+
@testset "Diagonal basic operations" begin
336+
for elty in (Float32, Float64, ComplexF32, ComplexF64)
337+
A = sprand(elty, 100, 100, 0.1)
338+
U2 = 2*I(100)
339+
U3 = Diagonal(rand(elty, 100))
340+
for SparseMatrixType in (CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO)
341+
B = SparseMatrixType(A)
342+
for op in (+, -, *)
343+
@test Array(op(B, U2)) op(A, U2) && Array(op(U2, B)) op(U2, A)
344+
@test Array(op(B, U3)) op(A, U3) && Array(op(U3, B)) op(U3, A)
345+
end
346+
end
347+
end
353348
end
354349

355350
@testset "dot(CuVector, CuSparseVector) and dot(CuSparseVector, CuVector) $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64],

0 commit comments

Comments
 (0)