Skip to content

Commit fb7440f

Browse files
authored
Add device counterparts for sparse arrays (#1106)
1 parent b04eb2f commit fb7440f

File tree

5 files changed

+177
-2
lines changed

5 files changed

+177
-2
lines changed

lib/cusparse/array.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,41 @@ function Base.show(io::IO, ::MIME"text/plain", S::AbstractCuSparseMatrix)
416416
show(IOContext(io, :typeinfo => eltype(S)), S)
417417
end
418418
end
419+
420+
421+
# interop with device arrays
422+
423+
Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseVector{Tv}) where {Tv} =
424+
CuSparseDeviceVector{Tv,Cint}(
425+
adapt(to, x.iPtr),
426+
adapt(to, x.nzVal),
427+
x.dims, x.nnz)
428+
429+
Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseMatrixCSR{Tv}) where {Tv} =
430+
CuSparseDeviceMatrixCSR{Tv,Cint}(
431+
adapt(to, x.rowPtr),
432+
adapt(to, x.colVal),
433+
adapt(to, x.nzVal),
434+
x.dims, x.nnz)
435+
436+
Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseMatrixCSC{Tv}) where {Tv} =
437+
CuSparseDeviceMatrixCSC{Tv,Cint}(
438+
adapt(to, x.colPtr),
439+
adapt(to, x.rowVal),
440+
adapt(to, x.nzVal),
441+
x.dims, x.nnz)
442+
443+
Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseMatrixBSR{Tv}) where {Tv} =
444+
CuSparseDeviceMatrixBSR{Tv,Cint}(
445+
adapt(to, x.rowPtr),
446+
adapt(to, x.colVal),
447+
adapt(to, x.nzVal),
448+
x.dims, x.blockDim,
449+
x.dir, x.nnz)
450+
451+
Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseMatrixCOO{Tv}) where {Tv} =
452+
CuSparseDeviceMatrixCOO{Tv,Cint}(
453+
adapt(to, x.rowInd),
454+
adapt(to, x.colInd),
455+
adapt(to, x.nzVal),
456+
x.dims, x.nnz)

src/CUDA.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ include("device/intrinsics.jl")
6161
include("device/runtime.jl")
6262
include("device/texture.jl")
6363
include("device/random.jl")
64+
include("device/sparse.jl")
6465
include("device/quirks.jl")
6566

6667
# array essentials

src/device/array.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ Base.@propagate_inbounds ldg(A::CuDeviceArray, i1::Integer) = const_arrayref(A,
225225
## other
226226

227227
Base.show(io::IO, a::CuDeviceVector) =
228-
print(io, "$(length(a))-element device array at $(pointer(a))")
228+
@printf(io, "%g-element device array at %p", length(a), Int(pointer(a)))
229229
Base.show(io::IO, a::CuDeviceArray) =
230-
print(io, "$(join(a.dims, '×')) device array at $(pointer(a))")
230+
@printf(io, "%s device array at %p", join(a.dims, '×'), Int(pointer(a)))
231231

232232
Base.show(io::IO, mime::MIME"text/plain", a::CuDeviceArray) = show(io, a)
233233

src/device/sparse.jl

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# on-device sparse array functionality
2+
3+
using SparseArrays
4+
5+
# NOTE: this functionality is currently very bare-bones, only defining the array types
6+
# without any device-compatible sparse array functionality
7+
8+
9+
# core types
10+
11+
export CuSparseDeviceVector, CuSparseDeviceMatrixCSC, CuSparseDeviceMatrixCSR,
12+
CuSparseDeviceMatrixBSR, CuSparseDeviceMatrixCOO
13+
14+
mutable struct CuSparseDeviceVector{Tv,Ti} <: AbstractSparseVector{Tv,Ti}
15+
iPtr::CuDeviceVector{Ti, AS.Global}
16+
nzVal::CuDeviceVector{Tv, AS.Global}
17+
dims::NTuple{2,Int}
18+
nnz::Int
19+
end
20+
21+
Base.length(g::CuSparseDeviceVector) = prod(g.dims)
22+
Base.size(g::CuSparseDeviceVector) = g.dims
23+
Base.ndims(g::CuSparseDeviceVector) = 1
24+
25+
mutable struct CuSparseDeviceMatrixCSC{Tv,Ti} <: AbstractSparseMatrix{Tv,Ti}
26+
colPtr::CuDeviceVector{Ti, AS.Global}
27+
rowVal::CuDeviceVector{Ti, AS.Global}
28+
nzVal::CuDeviceVector{Tv, AS.Global}
29+
dims::NTuple{2,Int}
30+
nnz::Int
31+
end
32+
33+
Base.length(g::CuSparseDeviceMatrixCSC) = prod(g.dims)
34+
Base.size(g::CuSparseDeviceMatrixCSC) = g.dims
35+
Base.ndims(g::CuSparseDeviceMatrixCSC) = 2
36+
37+
struct CuSparseDeviceMatrixCSR{Tv,Ti} <: AbstractSparseMatrix{Tv,Ti}
38+
rowPtr::CuDeviceVector{Ti, AS.Global}
39+
colVal::CuDeviceVector{Ti, AS.Global}
40+
nzVal::CuDeviceVector{Tv, AS.Global}
41+
dims::NTuple{2, Int}
42+
nnz::Int
43+
end
44+
45+
Base.length(g::CuSparseDeviceMatrixCSR) = prod(g.dims)
46+
Base.size(g::CuSparseDeviceMatrixCSR) = g.dims
47+
Base.ndims(g::CuSparseDeviceMatrixCSR) = 2
48+
49+
mutable struct CuSparseDeviceMatrixBSR{Tv,Ti} <: AbstractSparseMatrix{Tv,Ti}
50+
rowPtr::CuDeviceVector{Ti}
51+
colVal::CuDeviceVector{Ti}
52+
nzVal::CuDeviceVector{Tv}
53+
dims::NTuple{2,Int}
54+
blockDim::Int
55+
dir::Char
56+
nnz::Int
57+
end
58+
59+
Base.length(g::CuSparseDeviceMatrixBSR) = prod(g.dims)
60+
Base.size(g::CuSparseDeviceMatrixBSR) = g.dims
61+
Base.ndims(g::CuSparseDeviceMatrixBSR) = 2
62+
63+
struct CuSparseDeviceMatrixCOO{Tv,Ti} <: AbstractSparseMatrix{Tv,Ti}
64+
rowInd::CuDeviceVector{Ti}
65+
colInd::CuDeviceVector{Ti}
66+
nzVal::CuDeviceVector{Tv}
67+
dims::NTuple{2,Int}
68+
nnz::Int
69+
end
70+
71+
Base.length(g::CuSparseDeviceMatrixCOO) = prod(g.dims)
72+
Base.size(g::CuSparseDeviceMatrixCOO) = g.dims
73+
Base.ndims(g::CuSparseDeviceMatrixCOO) = 2
74+
75+
76+
# input/output
77+
78+
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceVector)
79+
println(io, "$(length(A))-element device sparse vector at:")
80+
println(io, " iPtr: $(A.iPtr)")
81+
print(io, " nzVal: $(A.nzVal)")
82+
end
83+
84+
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixCSR)
85+
println(io, "$(length(A))-element device sparse matrix CSR at:")
86+
println(io, " rowPtr: $(A.rowPtr)")
87+
println(io, " colVal: $(A.colVal)")
88+
print(io, " nzVal: $(A.nzVal)")
89+
end
90+
91+
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixCSC)
92+
println(io, "$(length(A))-element device sparse matrix CSC at:")
93+
println(io, " colPtr: $(A.colPtr)")
94+
println(io, " rowVal: $(A.rowVal)")
95+
print(io, " nzVal: $(A.nzVal)")
96+
end
97+
98+
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixBSR)
99+
println(io, "$(length(A))-element device sparse matrix BSR at:")
100+
println(io, " rowPtr: $(A.rowPtr)")
101+
println(io, " colVal: $(A.colVal)")
102+
print(io, " nzVal: $(A.nzVal)")
103+
end
104+
105+
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixCOO)
106+
println(io, "$(length(A))-element device sparse matrix COO at:")
107+
println(io, " rowPtr: $(A.rowPtr)")
108+
println(io, " colInd: $(A.colInd)")
109+
print(io, " nzVal: $(A.nzVal)")
110+
end

test/device/sparse.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using Test
2+
using CUDA
3+
using CUDA.CUSPARSE
4+
using SparseArrays
5+
using CUDA: CuSparseDeviceVector, CuSparseDeviceMatrixCSC, CuSparseDeviceMatrixCSR,
6+
CuSparseDeviceMatrixBSR, CuSparseDeviceMatrixCOO
7+
8+
@testset "cudaconvert" begin
9+
V = sprand(10, 0.5)
10+
cuV = CuSparseVector(V)
11+
@test cudaconvert(cuV) isa CuSparseDeviceVector
12+
13+
A = sprand(10, 10, 0.5)
14+
cuA = CuSparseMatrixCSC(A)
15+
@test cudaconvert(cuA) isa CuSparseDeviceMatrixCSC
16+
17+
cuA = CuSparseMatrixCSR(A)
18+
@test cudaconvert(cuA) isa CuSparseDeviceMatrixCSR
19+
20+
cuA = CuSparseMatrixCOO(A)
21+
@test cudaconvert(cuA) isa CuSparseDeviceMatrixCOO
22+
23+
# Roger-Luo: I'm not sure how to create a BSR matrix
24+
# cuA = CuSparseMatrixBSR(A)
25+
# @test cudaconvert(cuA) isa CuSparseDeviceMatrixBSR
26+
end

0 commit comments

Comments
 (0)