Skip to content

Commit 3122ba8

Browse files
authored
CUSPARSE: Eagerly combine duplicate element on construction. (#2213)
1 parent eef9e3b commit 3122ba8

File tree

2 files changed

+127
-13
lines changed

2 files changed

+127
-13
lines changed

lib/cusparse/conversions.jl

Lines changed: 103 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,113 @@ function SparseArrays.sparse(x::DenseCuMatrix; fmt=:csc)
2727
end
2828
end
2929

30-
SparseArrays.sparse(I::CuVector, J::CuVector, V::CuVector; kws...) =
31-
sparse(I, J, V, maximum(I), maximum(J); kws...)
30+
function SparseArrays.sparse(I::CuVector, J::CuVector, V::CuVector, args...; kwargs...)
31+
sparse(Cint.(I), Cint.(J), V, args...; kwargs...)
32+
end
3233

33-
SparseArrays.sparse(I::CuVector, J::CuVector, V::CuVector, m, n; kws...) =
34-
sparse(Cint.(I), Cint.(J), V, m, n; kws...)
34+
function SparseArrays.sparse(I::CuVector{Cint}, J::CuVector{Cint}, V::CuVector{Tv},
35+
m=maximum(I), n=maximum(J);
36+
fmt=:csc, combine=nothing) where Tv
37+
# we use COO as an intermediate format, as it's easy to construct from I/J/V.
38+
coo = CuSparseMatrixCOO{Tv}(I, J, V, (m, n))
3539

36-
function SparseArrays.sparse(I::CuVector{Cint}, J::CuVector{Cint}, V::CuVector{Tv}, m, n;
37-
fmt=:csc) where Tv
40+
# find groups of values that correspond to the same position in the matrix.
41+
# if there's no duplicates, `groups` will just be a vector of ones.
42+
# otherwise, it will contain the number of duplicates for each group,
43+
# with subsequent values that are part of the group set to zero.
44+
coo = sort_coo(coo, 'R')
45+
groups = similar(I, Int)
46+
function find_groups(groups, I, J)
47+
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
48+
if i > length(groups)
49+
return
50+
end
51+
len = 0
3852

39-
coo = CuSparseMatrixCOO{Tv}(I, J, V, (m, n))
40-
if fmt == :csc
53+
# check if we're at the start of a new group
54+
@inbounds if i == 1 || I[i] != I[i-1] || J[i] != J[i-1]
55+
len = 1
56+
while i+len <= length(groups) && I[i] == I[i+len] && J[i] == J[i+len]
57+
len += 1
58+
end
59+
end
60+
61+
@inbounds groups[i] = len
62+
63+
return
64+
end
65+
kernel = @cuda launch=false find_groups(groups, coo.rowInd, coo.colInd)
66+
config = launch_configuration(kernel.fun)
67+
threads = min(length(groups), config.threads)
68+
blocks = cld(length(groups), threads)
69+
kernel(groups, coo.rowInd, coo.colInd; threads, blocks)
70+
71+
# if we got any group of more than one element, we need to combine them.
72+
# this may actually not be required, as some CUSPARSE functions can handle
73+
# duplicate entries, but it's not clear which ones do and which ones don't.
74+
# also, to ensure matrix display is correct, combine values eagerly.
75+
ngroups = mapreduce(!iszero, +, groups)
76+
if ngroups != length(groups)
77+
if combine === nothing
78+
combine = if Tv === Bool
79+
|
80+
else
81+
+
82+
end
83+
end
84+
85+
total_lengths = cumsum(groups) # TODO: add and use `scan!(; exclusive=true)`
86+
I = similar(I, ngroups)
87+
J = similar(J, ngroups)
88+
V = similar(V, ngroups)
89+
90+
# use one thread per value, and if it's at the start of a group,
91+
# combine (if needed) all values and update the output vectors.
92+
function combine_groups(groups, total_lengths, oldI, oldJ, oldV, newI, newJ, newV, combine)
93+
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
94+
if i > length(groups)
95+
return
96+
end
97+
98+
# check if we're at the start of a group
99+
@inbounds if groups[i] != 0
100+
# get an exclusive offset from the inclusive cumsum
101+
offset = total_lengths[i] - groups[i] + 1
102+
103+
# copy values
104+
newI[i] = oldI[offset]
105+
newJ[i] = oldJ[offset]
106+
newV[i] = if groups[i] == 1
107+
oldV[offset]
108+
else
109+
# combine all values in the group
110+
val = oldV[offset]
111+
j = 1
112+
while j < groups[i]
113+
val = combine(val, oldV[offset+j])
114+
j += 1
115+
end
116+
val
117+
end
118+
end
119+
120+
return
121+
end
122+
kernel = @cuda launch=false combine_groups(groups, total_lengths, coo.rowInd, coo.colInd, coo.nzVal, I, J, V, combine)
123+
config = launch_configuration(kernel.fun)
124+
threads = min(length(groups), config.threads)
125+
blocks = cld(length(groups), threads)
126+
kernel(groups, total_lengths, coo.rowInd, coo.colInd, coo.nzVal, I, J, V, combine; threads, blocks)
127+
128+
coo = CuSparseMatrixCOO{Tv}(I, J, V, (m, n))
129+
end
130+
131+
if fmt == :coo
132+
return coo
133+
elseif fmt == :csc
41134
return CuSparseMatrixCSC(coo)
42135
elseif fmt == :csr
43136
return CuSparseMatrixCSR(coo)
44-
elseif fmt == :coo
45-
# The COO format is assumed to be sorted by row.
46-
return sort_coo(coo, 'R')
47137
else
48138
error("Format :$fmt not available, use :csc, :csr, or :coo.")
49139
end
@@ -231,7 +321,7 @@ for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR, :CuSparseMatrix
231321
$SparseMatrixType(S::Diagonal{Tv, <:CuArray}) where Tv = $SparseMatrixType{Tv}(S)
232322
$SparseMatrixType{Tv}(S::Diagonal) where {Tv} = $SparseMatrixType{Tv, Cint}(S)
233323
end
234-
324+
235325
if SparseMatrixType == :CuSparseMatrixCOO
236326
@eval function $SparseMatrixType{Tv, Ti}(S::Diagonal) where {Tv, Ti}
237327
m = size(S, 1)
@@ -242,7 +332,7 @@ for SparseMatrixType in (:CuSparseMatrixCSC, :CuSparseMatrixCSR, :CuSparseMatrix
242332
m = size(S, 1)
243333
return $SparseMatrixType{Tv, Ti}(CuVector(1:(m+1)), CuVector(1:m), convert(CuVector{Tv}, S.diag), (m, m))
244334
end
245-
end
335+
end
246336
end
247337

248338
# by flipping rows and columns, we can use that to get CSC to CSR too

test/libraries/cusparse.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,3 +1078,27 @@ end
10781078
end
10791079
end
10801080
end
1081+
1082+
@testset "duplicate entries" begin
1083+
# already sorted
1084+
let
1085+
I = [1, 3, 4, 4]
1086+
J = [1, 2, 3, 3]
1087+
V = [1f0, 2f0, 3f0, 10f0]
1088+
coo = sparse(cu(I), cu(J), cu(V); fmt=:coo)
1089+
@test Array(coo.rowInd) == [1, 3, 4]
1090+
@test Array(coo.colInd) == [1, 2, 3]
1091+
@test Array(coo.nzVal) == [1f0, 2f0, 13f0]
1092+
end
1093+
1094+
# out of order
1095+
let
1096+
I = [4, 1, 3, 4]
1097+
J = [3, 1, 2, 3]
1098+
V = [10f0, 1f0, 2f0, 3f0]
1099+
coo = sparse(cu(I), cu(J), cu(V); fmt=:coo)
1100+
@test Array(coo.rowInd) == [1, 3, 4]
1101+
@test Array(coo.colInd) == [1, 2, 3]
1102+
@test Array(coo.nzVal) == [1f0, 2f0, 13f0]
1103+
end
1104+
end

0 commit comments

Comments
 (0)