@@ -7,14 +7,15 @@ export CuSparseMatrixCSC, CuSparseMatrixCSR,
7
7
CuSparseVector
8
8
9
9
import Base: length, size, ndims, eltype, similar, pointer, stride,
10
- copy, convert, reinterpret, show, summary, copyto!, get!, fill!, collect
10
+ copy, convert, reinterpret, show, summary, copyto!, getindex, get!, fill!, collect
11
11
12
12
using LinearAlgebra
13
13
import LinearAlgebra: BlasFloat, Hermitian, HermOrSym, issymmetric, Transpose, Adjoint,
14
14
ishermitian, istriu, istril, Symmetric, UpperTriangular, LowerTriangular
15
15
16
16
using SparseArrays
17
- import SparseArrays: sparse, SparseMatrixCSC
17
+ import SparseArrays: sparse, SparseMatrixCSC, nnz, nonzeros, nonzeroinds,
18
+ _spgetindex
18
19
19
20
abstract type AbstractCuSparseArray{Tv, N} <: AbstractSparseArray{Tv, Cint, N} end
20
21
const AbstractCuSparseVector{Tv} = AbstractCuSparseArray{Tv,1 }
@@ -166,6 +167,11 @@ function size(g::CuSparseMatrix, d::Integer)
166
167
end
167
168
end
168
169
170
+ nnz (g:: AbstractCuSparseArray ) = g. nnz
171
+ nonzeros (g:: AbstractCuSparseArray ) = g. nzVal
172
+
173
+ nonzeroinds (g:: AbstractCuSparseVector ) = g. iPtr
174
+
169
175
issymmetric (M:: Union{CuSparseMatrixCSC,CuSparseMatrixCSR} ) = false
170
176
ishermitian (M:: Union{CuSparseMatrixCSC,CuSparseMatrixCSR} ) = false
171
177
issymmetric (M:: Symmetric{CuSparseMatrixCSC} ) = true
@@ -177,6 +183,67 @@ istriu(M::LowerTriangular{T,S}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix}
177
183
istril (M:: LowerTriangular{T,S} ) where {T<: BlasFloat , S<: AbstractCuSparseMatrix } = true
178
184
eltype (g:: CuSparseMatrix{T} ) where T = T
179
185
186
+ # getindex (mostly adapted from stdlib/SparseArrays)
187
+
188
+ # Translations
189
+ getindex (A:: AbstractCuSparseVector , :: Colon ) = copy (A)
190
+ getindex (A:: AbstractCuSparseMatrix , :: Colon , :: Colon ) = copy (A)
191
+ getindex (A:: AbstractCuSparseMatrix , i, :: Colon ) = getindex (A, i, 1 : size (A, 2 ))
192
+ getindex (A:: AbstractCuSparseMatrix , :: Colon , i) = getindex (A, 1 : size (A, 1 ), i)
193
+ getindex (A:: AbstractCuSparseMatrix , I:: Tuple{Integer,Integer} ) = getindex (A, I[1 ], I[2 ])
194
+
195
+ # Column slices
196
+ function getindex (x:: CuSparseMatrixCSC , :: Colon , j:: Integer )
197
+ checkbounds (x, :, j)
198
+ r1 = convert (Int, x. colPtr[j])
199
+ r2 = convert (Int, x. colPtr[j+ 1 ]) - 1
200
+ CuSparseVector (x. rowVal[r1: r2], x. nzVal[r1: r2], size (x, 1 ))
201
+ end
202
+
203
+ function getindex (x:: CuSparseMatrixCSR , i:: Integer , :: Colon )
204
+ checkbounds (x, :, i)
205
+ c1 = convert (Int, x. rowPtr[i])
206
+ c2 = convert (Int, x. rowPtr[i+ 1 ]) - 1
207
+ CuSparseVector (x. colVal[c1: c2], x. nzVal[c1: c2], size (x, 2 ))
208
+ end
209
+
210
+ # Row slices
211
+ # TODO optimize
212
+ getindex (A:: CuSparseMatrixCSC , i:: Integer , :: Colon ) = CuSparseVector (sparse (A[i, 1 : end ]))
213
+ # TODO optimize
214
+ getindex (A:: CuSparseMatrixCSR , :: Colon , j:: Integer ) = CuSparseVector (sparse (A[1 : end , j]))
215
+
216
+ function getindex (A:: CuSparseMatrixCSC{T} , i0:: Integer , i1:: Integer ) where T
217
+ m, n = size (A)
218
+ if ! (1 <= i0 <= m && 1 <= i1 <= n)
219
+ throw (BoundsError ())
220
+ end
221
+ r1 = Int (A. colPtr[i1])
222
+ r2 = Int (A. colPtr[i1+ 1 ]- 1 )
223
+ (r1 > r2) && return zero (T)
224
+ r1 = searchsortedfirst (A. rowVal, i0, r1, r2, Base. Order. Forward)
225
+ ((r1 > r2) || (A. rowVal[r1] != i0)) ? zero (T) : A. nzVal[r1]
226
+ end
227
+
228
+ function getindex (A:: CuSparseMatrixCSR{T} , i0:: Integer , i1:: Integer ) where T
229
+ m, n = size (A)
230
+ if ! (1 <= i0 <= m && 1 <= i1 <= n)
231
+ throw (BoundsError ())
232
+ end
233
+ c1 = Int (A. rowPtr[i0])
234
+ c2 = Int (A. rowPtr[i0+ 1 ]- 1 )
235
+ (c1 > c2) && return zero (T)
236
+ c1 = searchsortedfirst (A. colVal, i1, c1, c2, Base. Order. Forward)
237
+ ((c1 > c2) || (A. colVal[c1] != i1)) ? zero (T) : A. nzVal[c1]
238
+ end
239
+
240
+ # Called for indexing into `CuSparseVector`s
241
+ function _spgetindex (m:: Integer , nzind:: CuVector{Ti} , nzval:: CuVector{Tv} ,
242
+ i:: Integer ) where {Tv,Ti}
243
+ ii = searchsortedfirst (nzind, convert (Ti, i))
244
+ (ii <= m && nzind[ii] == i) ? nzval[ii] : zero (Tv)
245
+ end
246
+
180
247
function collect (Vec:: CuSparseVector )
181
248
SparseVector (Vec. dims[1 ], collect (Vec. iPtr), collect (Vec. nzVal))
182
249
end
0 commit comments