Skip to content

Commit cdae2d3

Browse files
authored
Add conversions between CuSparseVector and CuSparseMatrices (#2489)
1 parent 1e1f4ef commit cdae2d3

File tree

2 files changed

+61
-2
lines changed

2 files changed

+61
-2
lines changed

lib/cusparse/conversions.jl

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,12 +614,12 @@ function CuSparseMatrixCOO{Tv}(csc::CuSparseMatrixCSC{Tv}; index::SparseChar='O'
614614
coo = sort_coo(coo, 'R')
615615
end
616616

617-
### BSR to COO and viceversa
617+
### BSR to COO and vice-versa
618618

619619
CuSparseMatrixBSR(coo::CuSparseMatrixCOO, blockdim) = CuSparseMatrixBSR(CuSparseMatrixCSR(coo), blockdim) # no direct conversion
620620
CuSparseMatrixCOO(bsr::CuSparseMatrixBSR) = CuSparseMatrixCOO(CuSparseMatrixCSR(bsr)) # no direct conversion
621621

622-
### BSR to CSC and viceversa
622+
### BSR to CSC and vice-versa
623623

624624
CuSparseMatrixBSR(csc::CuSparseMatrixCSC, blockdim) = CuSparseMatrixBSR(CuSparseMatrixCSR(csc), blockdim) # no direct conversion
625625
CuSparseMatrixCSC(bsr::CuSparseMatrixBSR) = CuSparseMatrixCSC(CuSparseMatrixCSR(bsr)) # no direct conversion
@@ -668,3 +668,44 @@ end
668668
function CuSparseMatrixCOO(A::CuMatrix{T}; index::SparseChar='O') where {T}
669669
densetosparse(A, :coo, index)
670670
end
671+
672+
## CuSparseVector to CuSparseMatrices and vice-versa
673+
function CuSparseVector(A::CuSparseMatrixCSC{T}) where T
674+
m, n = size(A)
675+
(n == 1) || error("A doesn't have one column and can't be converted to a CuSparseVector.")
676+
CuSparseVector{T}(A.rowVal, A.nzVal, m)
677+
end
678+
679+
# no direct conversion
680+
function CuSparseVector(A::CuSparseMatrixCSR{T}) where T
681+
m, n = size(A)
682+
(n == 1) || error("A doesn't have one column and can't be converted to a CuSparseVector.")
683+
B = CuSparseMatrixCSC{T}(A)
684+
CuSparseVector(B)
685+
end
686+
687+
function CuSparseVector(A::CuSparseMatrixCOO{T}) where T
688+
m, n = size(A)
689+
(n == 1) || error("A doesn't have one column and can't be converted to a CuSparseVector.")
690+
CuSparseVector{T}(A.rowInd, A.nzVal, m)
691+
end
692+
693+
function CuSparseMatrixCSC(x::CuSparseVector{T}) where T
694+
n = length(x)
695+
colPtr = CuVector{Int32}([1; nnz(x)+1])
696+
CuSparseMatrixCSC{T}(colPtr, x.iPtr, nonzeros(x), (n,1))
697+
end
698+
699+
# no direct conversion
700+
function CuSparseMatrixCSR(x::CuSparseVector{T}) where T
701+
A = CuSparseMatrixCSC(x)
702+
CuSparseMatrixCSR{T}(A)
703+
end
704+
705+
function CuSparseMatrixCOO(x::CuSparseVector{T}) where T
706+
n = length(x)
707+
nnzx = nnz(x)
708+
colInd = CuVector{Int32}(undef, nnzx)
709+
fill!(colInd, one(Int32))
710+
CuSparseMatrixCOO{T}(x.iPtr, colInd, nonzeros(x), (n,1), nnzx)
711+
end

test/libraries/cusparse/conversions.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,24 @@ end
7474
end
7575
end
7676

77+
if !(v"12.0" <= CUSPARSE.version() < v"12.1")
78+
x = [0.0; 1.0; 2.0; 0.0; 3.0] |> SparseVector |> CuSparseVector
79+
A = Matrix{Float64}(undef, 5, 1)
80+
A[:, 1] .= [0.0; 1.0; 2.0; 0.0; 3.0]
81+
A = SparseMatrixCSC(A)
82+
for CuSparseMatrixType in (CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO)
83+
@testset "conversion CuSparseVector --> $CuSparseMatrixType" begin
84+
B = CuSparseMatrixType(x)
85+
@test collect(B)[:] collect(x)
86+
end
87+
@testset "conversion $CuSparseMatrixType --> CuSparseVector" begin
88+
B = CuSparseMatrixType(A)
89+
y = CuSparseVector(B)
90+
@test collect(B)[:] collect(y)
91+
end
92+
end
93+
end
94+
7795
for (n, bd, p) in [(100, 5, 0.02), (5, 1, 0.8), (4, 2, 0.5)]
7896
v"12.0" <= CUSPARSE.version() < v"12.1" && n == 4 && continue
7997
@testset "conversions between CuSparseMatrices (n, bd, p) = ($n, $bd, $p)" begin

0 commit comments

Comments
 (0)