@@ -39,8 +39,7 @@ function SparseArrays.sparse(I::CuVector{Cint}, J::CuVector{Cint}, V::CuVector{T

    # find groups of values that correspond to the same position in the matrix.
    # if there's no duplicates, `groups` will just be a vector of ones.
-    # otherwise, it will contain the number of duplicates for each group,
-    # with subsequent values that are part of the group set to zero.
+    # otherwise, it will contain gaps of zeros that indicate duplicate values.
    coo = sort_coo(coo, 'R')
    groups = similar(I, Int)
    function find_groups(groups, I, J)
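As a concrete illustration of the new comment (not part of this commit): for a row-sorted COO with one duplicated position, the group mask described above would look as follows, using made-up index vectors on the CPU:

    I = [1, 1, 2, 2, 2]    # illustrative row indices, already sorted
    J = [1, 3, 2, 2, 4]    # illustrative column indices; position (2, 2) appears twice
    groups = Int[i == 1 || I[i] != I[i-1] || J[i] != J[i-1] for i in 1:length(I)]
    # groups == [1, 1, 1, 0, 1]: a one at each group start, zeros for duplicates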
@@ -51,14 +50,7 @@ function SparseArrays.sparse(I::CuVector{Cint}, J::CuVector{Cint}, V::CuVector{T
        len = 0

        # check if we're at the start of a new group
-        @inbounds if i == 1 || I[i] != I[i-1] || J[i] != J[i-1]
-            len = 1
-            while i+len <= length(groups) && I[i] == I[i+len] && J[i] == J[i+len]
-                len += 1
-            end
-        end
-
-        @inbounds groups[i] = len
+        @inbounds groups[i] = i == 1 || I[i] != I[i-1] || J[i] != J[i-1]

        return
    end
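The replacement line stores a Bool into the Int-eltype `groups` array, which Julia converts to 0/1 on assignment; that is what lets the next hunk turn the mask into destination indices with an inclusive scan. A minimal CPU sketch of that mapping, reusing the illustrative mask from above:

    groups = [1, 1, 1, 0, 1]
    indices = accumulate(+, groups)
    # indices == [1, 2, 3, 3, 4]: each group start i writes to output slot indices[i]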
@@ -82,49 +74,45 @@ function SparseArrays.sparse(I::CuVector{Cint}, J::CuVector{Cint}, V::CuVector{T
            end
        end

-        total_lengths = cumsum(groups)  # TODO: add and use `scan!(; exclusive=true)`
+        # by scanning the mask of groups, we can find a mapping for old to new indices
+        indices = accumulate(+, groups)
+
        I = similar(I, ngroups)
        J = similar(J, ngroups)
        V = similar(V, ngroups)

-        # use one thread per value, and if it's at the start of a group,
+        # use one thread per (old) value, and if it's at the start of a group,
        # combine (if needed) all values and update the output vectors.
-        function combine_groups(groups, total_lengths, oldI, oldJ, oldV, newI, newJ, newV, combine)
+        function combine_groups(groups, indices, oldI, oldJ, oldV, newI, newJ, newV, combine)
            i = threadIdx().x + (blockIdx().x - 1) * blockDim().x
            if i > length(groups)
                return
            end

            # check if we're at the start of a group
            @inbounds if groups[i] != 0
-                # get an exclusive offset from the inclusive cumsum
-                offset = total_lengths[i] - groups[i] + 1
+                # get a destination index
+                j = indices[i]

                # copy values
-                newI[i] = oldI[offset]
-                newJ[i] = oldJ[offset]
-                newV[i] = if groups[i] == 1
-                    oldV[offset]
-                else
-                    # combine all values in the group
-                    val = oldV[offset]
-                    j = 1
-                    while j < groups[i]
-                        val = combine(val, oldV[offset+j])
-                        j += 1
-                    end
-                    val
+                newI[j] = oldI[i]
+                newJ[j] = oldJ[i]
+                val = oldV[i]
+                while i < length(groups) && groups[i+1] == 0
+                    i += 1
+                    val = combine(val, oldV[i])
                end
+                newV[j] = val
            end

            return
        end
-        kernel = @cuda launch=false combine_groups(groups, total_lengths, coo.rowInd, coo.colInd, coo.nzVal, I, J, V, combine)
+        kernel = @cuda launch=false combine_groups(groups, indices, coo.rowInd, coo.colInd, coo.nzVal, I, J, V, combine)
        config = launch_configuration(kernel.fun)
        threads = min(length(groups), config.threads)
        blocks = cld(length(groups), threads)
-        kernel(groups, total_lengths, coo.rowInd, coo.colInd, coo.nzVal, I, J, V, combine; threads, blocks)
-
+        kernel(groups, indices, coo.rowInd, coo.colInd, coo.nzVal, I, J, V, combine; threads, blocks)
+        synchronize()
        coo = CuSparseMatrixCOO{Tv}(I, J, V, (m, n))
    end

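For reference, here is a hedged sequential CPU sketch of the pipeline the updated code runs on the GPU: sort by (row, column), build the group-start mask, scan it into destination indices, then combine each run of duplicates. The helper name `combine_coo_cpu` and the default `combine=+` are illustrative assumptions; nothing below is part of the commit.

    # hypothetical CPU reference for the deduplication scheme above (not from the commit)
    function combine_coo_cpu(I, J, V, combine=+)
        p = sortperm(collect(zip(I, J)))   # sort by (row, col), analogous to sort_coo(coo, 'R')
        I, J, V = I[p], J[p], V[p]
        groups = Int[i == 1 || I[i] != I[i-1] || J[i] != J[i-1] for i in 1:length(I)]
        indices = accumulate(+, groups)    # inclusive scan: destination index per group start
        ngroups = isempty(indices) ? 0 : indices[end]
        newI = similar(I, ngroups); newJ = similar(J, ngroups); newV = similar(V, ngroups)
        for i in 1:length(groups)
            groups[i] == 0 && continue     # only group starts do any work
            j = indices[i]
            newI[j], newJ[j] = I[i], J[i]
            val = V[i]
            k = i
            while k < length(groups) && groups[k+1] == 0   # walk the run of duplicates
                k += 1
                val = combine(val, V[k])
            end
            newV[j] = val
        end
        return newI, newJ, newV
    end

    # e.g. combine_coo_cpu([1, 2, 2], [1, 2, 2], [10.0, 1.0, 2.0]) == ([1, 2], [1, 2], [10.0, 3.0])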