3
3
4
4
export CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixBSR, CuSparseMatrixCOO,
5
5
CuSparseMatrix, AbstractCuSparseMatrix,
6
+ CuSparseArrayCSR,
6
7
CuSparseVector,
7
8
CuSparseVecOrMat
8
9
141
142
142
143
CuSparseMatrixCOO (A:: CuSparseMatrixCOO ) = A
143
144
145
+ mutable struct CuSparseArrayCSR{Tv, Ti, N} <: AbstractCuSparseArray{Tv, Ti, N}
146
+ rowPtr:: CuArray{Ti}
147
+ colVal:: CuArray{Ti}
148
+ nzVal:: CuArray{Tv}
149
+ dims:: NTuple{N,Int}
150
+ nnz:: Ti
151
+
152
+ function CuSparseArrayCSR {Tv, Ti, N} (rowPtr:: CuArray{<:Integer, M} , colVal:: CuArray{<:Integer, M} , nzVal:: CuArray{Tv, M} , dims:: NTuple{N,<:Integer} ) where {Tv, Ti<: Integer , M, N}
153
+ @assert M == N - 1 " CuSparseArrayCSR requires ndims(rowPtr) == ndims(colVal) == ndims(nzVal) == length(dims) - 1"
154
+ new {Tv, Ti, N} (rowPtr, colVal, nzVal, dims, length (nzVal))
155
+ end
156
+ end
157
+
158
+ CuSparseArrayCSR (A:: CuSparseArrayCSR ) = A
159
+
160
+ function CUDA. unsafe_free! (xs:: CuSparseArrayCSR )
161
+ unsafe_free! (xs. rowPtr)
162
+ unsafe_free! (xs. colVal)
163
+ unsafe_free! (nonzeros (xs))
164
+ return
165
+ end
166
+
167
+ # broadcast over batch-dim if batchsize==1
168
+ ptrstride (A:: CuSparseArrayCSR ) = size (A. rowPtr, 2 ) > 1 ? stride (A. rowPtr, 2 ) : 0
169
+ valstride (A:: CuSparseArrayCSR ) = size (A. nzVal, 2 ) > 1 ? stride (A. nzVal, 2 ) : 0
170
+
144
171
"""
145
172
Utility union type of [`CuSparseMatrixCSC`](@ref), [`CuSparseMatrixCSR`](@ref),
146
173
[`CuSparseMatrixBSR`](@ref), [`CuSparseMatrixCOO`](@ref).
@@ -154,7 +181,6 @@ const CuSparseMatrix{Tv, Ti} = Union{
154
181
155
182
const CuSparseVecOrMat = Union{CuSparseVector,CuSparseMatrix}
156
183
157
-
158
184
# NOTE: we use Cint as default Ti on CUDA instead of Int to provide
159
185
# maximum compatiblity to old CUSPARSE APIs
160
186
function CuSparseVector {Tv} (iPtr:: CuVector{<:Integer} , nzVal:: CuVector , len:: Integer ) where {Tv}
@@ -183,6 +209,11 @@ function CuSparseMatrixCOO{Tv}(rowInd::CuVector{<:Integer}, colInd::CuVector{<:I
183
209
CuSparseMatrixCOO {Tv, Cint} (rowInd,colInd,nzVal,dims,nnz)
184
210
end
185
211
212
+ function CuSparseArrayCSR {Tv} (rowPtr:: CuArray{<:Integer, M} , colVal:: CuArray{<:Integer, M} ,
213
+ nzVal:: CuArray{Tv, M} , dims:: NTuple{N,<:Integer} ) where {Tv, M, N}
214
+ CuSparseArrayCSR {Tv, Cint, N} (rowPtr, colVal, nzVal, dims)
215
+ end
216
+
186
217
# # convenience constructors
187
218
CuSparseVector (iPtr:: DenseCuArray{<:Integer} , nzVal:: DenseCuArray{T} , len:: Integer ) where {T} =
188
219
CuSparseVector {T} (iPtr, nzVal, len)
@@ -201,6 +232,9 @@ CuSparseMatrixBSR(rowPtr::DenseCuArray, colVal::DenseCuArray, nzVal::DenseCuArra
201
232
CuSparseMatrixCOO (rowInd:: DenseCuArray , colInd:: DenseCuArray , nzVal:: DenseCuArray{T} , dims:: NTuple{2,<:Integer} , nnz:: Integer = length (nzVal)) where T =
202
233
CuSparseMatrixCOO {T} (rowInd, colInd, nzVal, dims, nnz)
203
234
235
+ CuSparseArrayCSR (rowPtr:: DenseCuArray , colVal:: DenseCuArray , nzVal:: DenseCuArray{T} , dims:: NTuple{N,<:Integer} ) where {T,N} =
236
+ CuSparseArrayCSR {T} (rowPtr, colVal, nzVal, dims)
237
+
204
238
Base. similar (Vec:: CuSparseVector ) = CuSparseVector (copy (nonzeroinds (Vec)), similar (nonzeros (Vec)), length (Vec))
205
239
Base. similar (Mat:: CuSparseMatrixCSC ) = CuSparseMatrixCSC (copy (Mat. colPtr), copy (rowvals (Mat)), similar (nonzeros (Mat)), size (Mat))
206
240
Base. similar (Mat:: CuSparseMatrixCSR ) = CuSparseMatrixCSR (copy (Mat. rowPtr), copy (Mat. colVal), similar (nonzeros (Mat)), size (Mat))
@@ -216,6 +250,7 @@ Base.similar(Mat::CuSparseMatrixCOO, T::Type) = CuSparseMatrixCOO(copy(Mat.rowIn
216
250
Base. similar (Mat:: CuSparseMatrixCSC , T:: Type , N:: Int , M:: Int ) = CuSparseMatrixCSC (CuVector {Int32} (undef, M+ 1 ), CuVector {Int32} (undef, nnz (Mat)), CuVector {T} (undef, nnz (Mat)), (N,M))
217
251
Base. similar (Mat:: CuSparseMatrixCSR , T:: Type , N:: Int , M:: Int ) = CuSparseMatrixCSR (CuVector {Int32} (undef, N+ 1 ), CuVector {Int32} (undef, nnz (Mat)), CuVector {T} (undef, nnz (Mat)), (N,M))
218
252
Base. similar (Mat:: CuSparseMatrixCOO , T:: Type , N:: Int , M:: Int ) = CuSparseMatrixCOO (CuVector {Int32} (undef, nnz (Mat)), CuVector {Int32} (undef, nnz (Mat)), CuVector {T} (undef, nnz (Mat)), (N,M))
253
+ Base. similar (Mat:: CuSparseArrayCSR ) = CuSparseArrayCSR (copy (Mat. rowPtr), copy (Mat. colVal), similar (nonzeros (Mat)), size (Mat))
219
254
220
255
# # array interface
221
256
@@ -225,6 +260,9 @@ Base.size(g::CuSparseVector) = (g.len,)
225
260
Base. length (g:: CuSparseMatrix ) = prod (g. dims)
226
261
Base. size (g:: CuSparseMatrix ) = g. dims
227
262
263
+ Base. length (g:: CuSparseArrayCSR ) = prod (g. dims)
264
+ Base. size (g:: CuSparseArrayCSR ) = g. dims
265
+
228
266
function Base. size (g:: CuSparseVector , d:: Integer )
229
267
if d == 1
230
268
return g. len
@@ -245,6 +283,15 @@ function Base.size(g::CuSparseMatrix, d::Integer)
245
283
end
246
284
end
247
285
286
+ function Base. size (g:: CuSparseArrayCSR{Tv,Ti,N} , d:: Integer ) where {Tv,Ti,N}
287
+ if 1 <= d <= N
288
+ return g. dims[d]
289
+ elseif d > 1
290
+ return 1
291
+ else
292
+ throw (ArgumentError (" dimension must be ≥ 1, got $d " ))
293
+ end
294
+ end
248
295
249
296
# # sparse array interface
250
297
@@ -348,6 +395,16 @@ function Base.getindex(A::CuSparseMatrixBSR{T}, i0::Integer, i1::Integer) where
348
395
nonzeros (A)[c1+ block_idx]
349
396
end
350
397
398
+ # matrix slices
399
+ function Base. getindex (A:: CuSparseArrayCSR{Tv, Ti, N} , :: Colon , :: Colon , idxs:: Integer... ) where {Tv, Ti, N}
400
+ @boundscheck checkbounds (A, :, :, idxs... )
401
+ CuSparseMatrixCSR (A. rowPtr[:,idxs... ], A. colVal[:,idxs... ], nonzeros (A)[:,idxs... ], size (A)[1 : 2 ])
402
+ end
403
+
404
+ function Base. getindex (A:: CuSparseArrayCSR{Tv, Ti, N} , i0:: Integer , i1:: Integer , idxs:: Integer... ) where {Tv, Ti, N}
405
+ @boundscheck checkbounds (A, i0, i1, idxs... )
406
+ CuSparseMatrixCSR (A. rowPtr[:,idxs... ], A. colVal[:,idxs... ], nonzeros (A)[:,idxs... ], size (A)[1 : 2 ])[i0, i1]
407
+ end
351
408
352
409
# # interop with sparse CPU arrays
353
410
@@ -502,7 +559,7 @@ Base.copy(Mat::CuSparseMatrixCSC) = copyto!(similar(Mat), Mat)
502
559
Base. copy (Mat:: CuSparseMatrixCSR ) = copyto! (similar (Mat), Mat)
503
560
Base. copy (Mat:: CuSparseMatrixBSR ) = copyto! (similar (Mat), Mat)
504
561
Base. copy (Mat:: CuSparseMatrixCOO ) = copyto! (similar (Mat), Mat)
505
-
562
+ Base . copy (Mat :: CuSparseArrayCSR ) = CuSparseArrayCSR ( copy (Mat . rowPtr), copy (Mat . colVal), copy ( nonzeros (Mat)), size (Mat))
506
563
507
564
# input/output
508
565
@@ -543,6 +600,24 @@ for (gpu, cpu) in [:CuSparseMatrixCSC => :SparseMatrixCSC,
543
600
end
544
601
end
545
602
603
+ function Base. show (io:: IOContext , :: MIME"text/plain" , A:: CuSparseArrayCSR )
604
+ xnnz = nnz (A)
605
+ dims = join (size (A), " ×" )
606
+
607
+ print (io, dims... , " " , typeof (A), " with " , xnnz, " stored " , xnnz == 1 ? " entry" : " entries" )
608
+
609
+ if all (size (A) .> 0 )
610
+ println (io, " :" )
611
+ io = IOContext (io, :typeinfo => eltype (A))
612
+ for (k, c) in enumerate (CartesianIndices (size (A)[3 : end ]))
613
+ k > 1 && println (io, " \n " )
614
+ dims = join (c. I, " , " )
615
+ println (io, " [:, :, $dims ] =" )
616
+ Base. print_array (io, SparseMatrixCSC (A[:,:,c. I... ]))
617
+ end
618
+ end
619
+ end
620
+
546
621
547
622
# interop with device arrays
548
623
@@ -590,3 +665,13 @@ function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCOO)
590
665
size (x), x. nnz
591
666
)
592
667
end
668
+
669
+ function Adapt. adapt_structure (to:: CUDA.KernelAdaptor , x:: CuSparseArrayCSR )
670
+ return CuSparseDeviceArrayCSR (
671
+ adapt (to, x. rowPtr),
672
+ adapt (to, x. colVal),
673
+ adapt (to, x. nzVal),
674
+ size (x), x. nnz
675
+ )
676
+ end
677
+
0 commit comments