Skip to content

Commit 991e23a

Browse files
authored
CUSPARSE: Fix sparse constructor with duplicate elements. (#2495)
1 parent 1b0bbc6 commit 991e23a

File tree

2 files changed

+30
-31
lines changed

2 files changed

+30
-31
lines changed

lib/cusparse/conversions.jl

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ function SparseArrays.sparse(I::CuVector{Cint}, J::CuVector{Cint}, V::CuVector{T
3939

4040
# find groups of values that correspond to the same position in the matrix.
4141
# if there are no duplicates, `groups` will just be a vector of ones.
42-
# otherwise, it will contain the number of duplicates for each group,
43-
# with subsequent values that are part of the group set to zero.
42+
# otherwise, it will contain gaps of zeros that indicate duplicate values.
4443
coo = sort_coo(coo, 'R')
4544
groups = similar(I, Int)
4645
function find_groups(groups, I, J)
@@ -51,14 +50,7 @@ function SparseArrays.sparse(I::CuVector{Cint}, J::CuVector{Cint}, V::CuVector{T
5150
len = 0
5251

5352
# check if we're at the start of a new group
54-
@inbounds if i == 1 || I[i] != I[i-1] || J[i] != J[i-1]
55-
len = 1
56-
while i+len <= length(groups) && I[i] == I[i+len] && J[i] == J[i+len]
57-
len += 1
58-
end
59-
end
60-
61-
@inbounds groups[i] = len
53+
@inbounds groups[i] = i == 1 || I[i] != I[i-1] || J[i] != J[i-1]
6254

6355
return
6456
end
@@ -82,49 +74,45 @@ function SparseArrays.sparse(I::CuVector{Cint}, J::CuVector{Cint}, V::CuVector{T
8274
end
8375
end
8476

85-
total_lengths = cumsum(groups) # TODO: add and use `scan!(; exclusive=true)`
77+
# by scanning the mask of groups, we can find a mapping from old to new indices
78+
indices = accumulate(+, groups)
79+
8680
I = similar(I, ngroups)
8781
J = similar(J, ngroups)
8882
V = similar(V, ngroups)
8983

90-
# use one thread per value, and if it's at the start of a group,
84+
# use one thread per (old) value, and if it's at the start of a group,
9185
# combine (if needed) all values and update the output vectors.
92-
function combine_groups(groups, total_lengths, oldI, oldJ, oldV, newI, newJ, newV, combine)
86+
function combine_groups(groups, indices, oldI, oldJ, oldV, newI, newJ, newV, combine)
9387
i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
9488
if i > length(groups)
9589
return
9690
end
9791

9892
# check if we're at the start of a group
9993
@inbounds if groups[i] != 0
100-
# get an exclusive offset from the inclusive cumsum
101-
offset = total_lengths[i] - groups[i] + 1
94+
# get a destination index
95+
j = indices[i]
10296

10397
# copy values
104-
newI[i] = oldI[offset]
105-
newJ[i] = oldJ[offset]
106-
newV[i] = if groups[i] == 1
107-
oldV[offset]
108-
else
109-
# combine all values in the group
110-
val = oldV[offset]
111-
j = 1
112-
while j < groups[i]
113-
val = combine(val, oldV[offset+j])
114-
j += 1
115-
end
116-
val
98+
newI[j] = oldI[i]
99+
newJ[j] = oldJ[i]
100+
val = oldV[i]
101+
while i < length(groups) && groups[i+1] == 0
102+
i += 1
103+
val = combine(val, oldV[i])
117104
end
105+
newV[j] = val
118106
end
119107

120108
return
121109
end
122-
kernel = @cuda launch=false combine_groups(groups, total_lengths, coo.rowInd, coo.colInd, coo.nzVal, I, J, V, combine)
110+
kernel = @cuda launch=false combine_groups(groups, indices, coo.rowInd, coo.colInd, coo.nzVal, I, J, V, combine)
123111
config = launch_configuration(kernel.fun)
124112
threads = min(length(groups), config.threads)
125113
blocks = cld(length(groups), threads)
126-
kernel(groups, total_lengths, coo.rowInd, coo.colInd, coo.nzVal, I, J, V, combine; threads, blocks)
127-
114+
kernel(groups, indices, coo.rowInd, coo.colInd, coo.nzVal, I, J, V, combine; threads, blocks)
115+
synchronize()
128116
coo = CuSparseMatrixCOO{Tv}(I, J, V, (m, n))
129117
end
130118

test/libraries/cusparse.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,4 +1117,15 @@ end
11171117
@test Array(coo.colInd) == [1, 2, 3]
11181118
@test Array(coo.nzVal) == [1f0, 2f0, 13f0]
11191119
end
1120+
1121+
# JuliaGPU/CUDA.jl#2494
1122+
let
1123+
I = [1, 2, 1]
1124+
J = [1, 2, 1]
1125+
V = [10f0, 1f0, 2f0]
1126+
coo = sparse(cu(I), cu(J), cu(V); fmt=:coo)
1127+
@test Array(coo.rowInd) == [1, 2]
1128+
@test Array(coo.colInd) == [1, 2]
1129+
@test Array(coo.nzVal) == [12f0, 1f0]
1130+
end
11201131
end

0 commit comments

Comments
 (0)