Skip to content

Commit 88ebe50

Browse files
nikopj and Nikola authored
Add CuSparseArrayCSR (N dim array) for batched matmatmul (bmm) (#1944)
Co-authored-by: Nikola <npj226@nyu.edu>
1 parent a152355 commit 88ebe50

File tree

8 files changed

+425
-2
lines changed

8 files changed

+425
-2
lines changed

lib/cusparse/CUSPARSE.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ include("util.jl")
3232
include("types.jl")
3333
include("linalg.jl")
3434

35+
3536
# low-level wrappers
3637
include("helpers.jl")
3738
include("management.jl")
@@ -51,6 +52,8 @@ include("device.jl")
5152
include("broadcast.jl")
5253
include("reduce.jl")
5354

55+
include("batched.jl")
56+
5457
# cache for created, but unused handles
5558
const idle_handles = HandleCache{CuContext,cusparseHandle_t}()
5659

lib/cusparse/array.jl

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
export CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixBSR, CuSparseMatrixCOO,
55
CuSparseMatrix, AbstractCuSparseMatrix,
6+
CuSparseArrayCSR,
67
CuSparseVector,
78
CuSparseVecOrMat
89

@@ -141,6 +142,32 @@ end
141142

142143
CuSparseMatrixCOO(A::CuSparseMatrixCOO) = A
143144

145+
"""
    CuSparseArrayCSR{Tv, Ti, N} <: AbstractCuSparseArray{Tv, Ti, N}

`N`-dimensional sparse array held as three `CuArray`s (`rowPtr`, `colVal`, `nzVal`)
plus its logical size `dims` and a stored-entry count `nnz`.

The inner constructor requires the three storage arrays to share the same number of
dimensions `M` with `M == N - 1`, and records `length(nzVal)` as `nnz`.
"""
mutable struct CuSparseArrayCSR{Tv, Ti, N} <: AbstractCuSparseArray{Tv, Ti, N}
    rowPtr::CuArray{Ti}
    colVal::CuArray{Ti}
    nzVal::CuArray{Tv}
    dims::NTuple{N,Int}
    nnz::Ti

    function CuSparseArrayCSR{Tv, Ti, N}(
        rowPtr::CuArray{<:Integer, M},
        colVal::CuArray{<:Integer, M},
        nzVal::CuArray{Tv, M},
        dims::NTuple{N,<:Integer},
    ) where {Tv, Ti<:Integer, M, N}
        @assert N - 1 == M "CuSparseArrayCSR requires ndims(rowPtr) == ndims(colVal) == ndims(nzVal) == length(dims) - 1"
        # The stored-entry count is the total length of `nzVal`.
        stored = length(nzVal)
        return new{Tv, Ti, N}(rowPtr, colVal, nzVal, dims, stored)
    end
end
157+
158+
# Constructing a `CuSparseArrayCSR` from a `CuSparseArrayCSR` returns the argument itself.
function CuSparseArrayCSR(A::CuSparseArrayCSR)
    return A
end
159+
160+
# Call `unsafe_free!` on each of the three storage arrays of `xs`, in the order
# `rowPtr`, `colVal`, `nonzeros(xs)`. Returns `nothing`.
function CUDA.unsafe_free!(xs::CuSparseArrayCSR)
    for buf in (xs.rowPtr, xs.colVal)
        unsafe_free!(buf)
    end
    unsafe_free!(nonzeros(xs))
    return nothing
end
166+
167+
# broadcast over batch-dim if batchsize==1
168+
# Stride of `rowPtr` along its second dimension, or 0 when that dimension has size ≤ 1.
function ptrstride(A::CuSparseArrayCSR)
    ptrs = A.rowPtr
    size(ptrs, 2) > 1 || return 0
    return stride(ptrs, 2)
end

# Stride of `nzVal` along its second dimension, or 0 when that dimension has size ≤ 1.
function valstride(A::CuSparseArrayCSR)
    vals = A.nzVal
    size(vals, 2) > 1 || return 0
    return stride(vals, 2)
end
170+
144171
"""
145172
Utility union type of [`CuSparseMatrixCSC`](@ref), [`CuSparseMatrixCSR`](@ref),
146173
[`CuSparseMatrixBSR`](@ref), [`CuSparseMatrixCOO`](@ref).
@@ -154,7 +181,6 @@ const CuSparseMatrix{Tv, Ti} = Union{
154181

155182
const CuSparseVecOrMat = Union{CuSparseVector,CuSparseMatrix}
156183

157-
158184
# NOTE: we use Cint as default Ti on CUDA instead of Int to provide
159185
# maximum compatibility with old CUSPARSE APIs
160186
function CuSparseVector{Tv}(iPtr::CuVector{<:Integer}, nzVal::CuVector, len::Integer) where {Tv}
@@ -183,6 +209,11 @@ function CuSparseMatrixCOO{Tv}(rowInd::CuVector{<:Integer}, colInd::CuVector{<:I
183209
CuSparseMatrixCOO{Tv, Cint}(rowInd,colInd,nzVal,dims,nnz)
184210
end
185211

212+
# When only the value type `Tv` is given, the index type defaults to `Cint`.
CuSparseArrayCSR{Tv}(rowPtr::CuArray{<:Integer, M}, colVal::CuArray{<:Integer, M},
                     nzVal::CuArray{Tv, M}, dims::NTuple{N,<:Integer}) where {Tv, M, N} =
    CuSparseArrayCSR{Tv, Cint, N}(rowPtr, colVal, nzVal, dims)
216+
186217
## convenience constructors
187218
CuSparseVector(iPtr::DenseCuArray{<:Integer}, nzVal::DenseCuArray{T}, len::Integer) where {T} =
188219
CuSparseVector{T}(iPtr, nzVal, len)
@@ -201,6 +232,9 @@ CuSparseMatrixBSR(rowPtr::DenseCuArray, colVal::DenseCuArray, nzVal::DenseCuArra
201232
CuSparseMatrixCOO(rowInd::DenseCuArray, colInd::DenseCuArray, nzVal::DenseCuArray{T}, dims::NTuple{2,<:Integer}, nnz::Integer=length(nzVal)) where T =
202233
CuSparseMatrixCOO{T}(rowInd, colInd, nzVal, dims, nnz)
203234

235+
# Convenience constructor: the value type is taken from the element type of `nzVal`.
function CuSparseArrayCSR(rowPtr::DenseCuArray, colVal::DenseCuArray,
                          nzVal::DenseCuArray{T}, dims::NTuple{N,<:Integer}) where {T,N}
    return CuSparseArrayCSR{T}(rowPtr, colVal, nzVal, dims)
end
237+
204238
Base.similar(Vec::CuSparseVector) = CuSparseVector(copy(nonzeroinds(Vec)), similar(nonzeros(Vec)), length(Vec))
205239
Base.similar(Mat::CuSparseMatrixCSC) = CuSparseMatrixCSC(copy(Mat.colPtr), copy(rowvals(Mat)), similar(nonzeros(Mat)), size(Mat))
206240
Base.similar(Mat::CuSparseMatrixCSR) = CuSparseMatrixCSR(copy(Mat.rowPtr), copy(Mat.colVal), similar(nonzeros(Mat)), size(Mat))
@@ -216,6 +250,7 @@ Base.similar(Mat::CuSparseMatrixCOO, T::Type) = CuSparseMatrixCOO(copy(Mat.rowIn
216250
Base.similar(Mat::CuSparseMatrixCSC, T::Type, N::Int, M::Int) = CuSparseMatrixCSC(CuVector{Int32}(undef, M+1), CuVector{Int32}(undef, nnz(Mat)), CuVector{T}(undef, nnz(Mat)), (N,M))
217251
Base.similar(Mat::CuSparseMatrixCSR, T::Type, N::Int, M::Int) = CuSparseMatrixCSR(CuVector{Int32}(undef, N+1), CuVector{Int32}(undef, nnz(Mat)), CuVector{T}(undef, nnz(Mat)), (N,M))
218252
Base.similar(Mat::CuSparseMatrixCOO, T::Type, N::Int, M::Int) = CuSparseMatrixCOO(CuVector{Int32}(undef, nnz(Mat)), CuVector{Int32}(undef, nnz(Mat)), CuVector{T}(undef, nnz(Mat)), (N,M))
253+
# New `CuSparseArrayCSR` of the same size: `rowPtr` and `colVal` are copied, while the
# value array is created with `similar` from the existing one.
function Base.similar(Mat::CuSparseArrayCSR)
    rowPtr = copy(Mat.rowPtr)
    colVal = copy(Mat.colVal)
    nzVal = similar(nonzeros(Mat))
    return CuSparseArrayCSR(rowPtr, colVal, nzVal, size(Mat))
end
219254

220255
## array interface
221256

@@ -225,6 +260,9 @@ Base.size(g::CuSparseVector) = (g.len,)
225260
Base.length(g::CuSparseMatrix) = prod(g.dims)
226261
Base.size(g::CuSparseMatrix) = g.dims
227262

263+
# Number of logical elements: the product of all entries of `dims`.
function Base.length(g::CuSparseArrayCSR)
    return prod(g.dims)
end

# Logical size tuple, as stored in the `dims` field.
function Base.size(g::CuSparseArrayCSR)
    return g.dims
end
265+
228266
function Base.size(g::CuSparseVector, d::Integer)
229267
if d == 1
230268
return g.len
@@ -245,6 +283,15 @@ function Base.size(g::CuSparseMatrix, d::Integer)
245283
end
246284
end
247285

286+
# Size along dimension `d`: the stored extent for `1 <= d <= N`, `1` for any other
# `d > 1`, and an `ArgumentError` otherwise.
function Base.size(g::CuSparseArrayCSR{Tv,Ti,N}, d::Integer) where {Tv,Ti,N}
    1 <= d <= N && return g.dims[d]
    d > 1 && return 1
    throw(ArgumentError("dimension must be ≥ 1, got $d"))
end
248295

249296
## sparse array interface
250297

@@ -348,6 +395,16 @@ function Base.getindex(A::CuSparseMatrixBSR{T}, i0::Integer, i1::Integer) where
348395
nonzeros(A)[c1+block_idx]
349396
end
350397

398+
# matrix slices
399+
# `A[:, :, idxs...]`: build a `CuSparseMatrixCSR` from the slices of `rowPtr`, `colVal`
# and the value array selected by `idxs`, with the first two dimensions of `A` as size.
function Base.getindex(A::CuSparseArrayCSR{Tv, Ti, N}, ::Colon, ::Colon, idxs::Integer...) where {Tv, Ti, N}
    @boundscheck checkbounds(A, :, :, idxs...)
    rowptr = A.rowPtr[:, idxs...]
    colval = A.colVal[:, idxs...]
    nzval = nonzeros(A)[:, idxs...]
    return CuSparseMatrixCSR(rowptr, colval, nzval, size(A)[1:2])
end
403+
404+
# `A[i0, i1, idxs...]`: build the `CuSparseMatrixCSR` slice selected by `idxs`, then
# index that matrix at `(i0, i1)`.
function Base.getindex(A::CuSparseArrayCSR{Tv, Ti, N}, i0::Integer, i1::Integer, idxs::Integer...) where {Tv, Ti, N}
    @boundscheck checkbounds(A, i0, i1, idxs...)
    rowptr = A.rowPtr[:, idxs...]
    colval = A.colVal[:, idxs...]
    nzval = nonzeros(A)[:, idxs...]
    slice = CuSparseMatrixCSR(rowptr, colval, nzval, size(A)[1:2])
    return slice[i0, i1]
end
351408

352409
## interop with sparse CPU arrays
353410

@@ -502,7 +559,7 @@ Base.copy(Mat::CuSparseMatrixCSC) = copyto!(similar(Mat), Mat)
502559
Base.copy(Mat::CuSparseMatrixCSR) = copyto!(similar(Mat), Mat)
503560
Base.copy(Mat::CuSparseMatrixBSR) = copyto!(similar(Mat), Mat)
504561
Base.copy(Mat::CuSparseMatrixCOO) = copyto!(similar(Mat), Mat)
505-
562+
# Copy all three storage arrays and wrap them in a new `CuSparseArrayCSR` of the same size.
function Base.copy(Mat::CuSparseArrayCSR)
    rowPtr = copy(Mat.rowPtr)
    colVal = copy(Mat.colVal)
    nzVal = copy(nonzeros(Mat))
    return CuSparseArrayCSR(rowPtr, colVal, nzVal, size(Mat))
end
506563

507564
# input/output
508565

@@ -543,6 +600,24 @@ for (gpu, cpu) in [:CuSparseMatrixCSC => :SparseMatrixCSC,
543600
end
544601
end
545602

603+
# Multi-line display: a summary line (size, type, stored-entry count), then, when every
# dimension is non-empty, one `[:, :, i, …] =` section per index over dimensions 3 and up.
# Each section prints that slice after converting it to a `SparseMatrixCSC`.
function Base.show(io::IOContext, ::MIME"text/plain", A::CuSparseArrayCSR)
    stored = nnz(A)
    shape = join(size(A), "×")
    noun = stored == 1 ? "entry" : "entries"
    print(io, shape, " ", typeof(A), " with ", stored, " stored ", noun)

    # Nothing beyond the summary line when any dimension is empty.
    all(size(A) .> 0) || return nothing

    println(io, ":")
    ctx = IOContext(io, :typeinfo => eltype(A))
    for (k, c) in enumerate(CartesianIndices(size(A)[3:end]))
        # Separate consecutive slices with a blank line.
        k > 1 && println(ctx, "\n")
        println(ctx, "[:, :, ", join(c.I, ", "), "] =")
        Base.print_array(ctx, SparseMatrixCSC(A[:,:,c.I...]))
    end
    return nothing
end
620+
546621

547622
# interop with device arrays
548623

@@ -590,3 +665,13 @@ function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseMatrixCOO)
590665
size(x), x.nnz
591666
)
592667
end
668+
669+
# Adapt each storage array with `to` and bundle the results, together with the size and
# `nnz` of `x`, into a `CuSparseDeviceArrayCSR`.
function Adapt.adapt_structure(to::CUDA.KernelAdaptor, x::CuSparseArrayCSR)
    rowptr = adapt(to, x.rowPtr)
    colval = adapt(to, x.colVal)
    nzval = adapt(to, x.nzVal)
    return CuSparseDeviceArrayCSR(rowptr, colval, nzval, size(x), x.nnz)
end
677+

lib/cusparse/batched.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
    cat(As::CuSparseMatrixCSR...; dims=3)

Concatenate CSR matrices along dimension `dims`.

For `dims ≥ 3`, the `rowPtr`, `colVal` and `nzVal` arrays of the inputs are concatenated
along dimension `dims - 1` and wrapped in a `CuSparseArrayCSR` of size
`(size(As[1])..., 1, …, 1, length(As))`, with `dims - 3` singleton dimensions.

`dims == 1` forwards to `vcat(As...)` and `dims == 2` to `hcat(As...)`, matching the
meaning of `dims` in `Base.cat`.
"""
function Base.cat(As::CuSparseMatrixCSR...; dims=3)
    # Dimension 1 is vertical and dimension 2 is horizontal concatenation; the inputs
    # are splatted so that `vcat`/`hcat` receive the matrices rather than one tuple.
    if dims == 1
        return vcat(As...)
    elseif dims == 2
        return hcat(As...)
    end

    # The storage arrays have one dimension fewer than the logical array.
    bdim = dims - 1
    newsize = (size(As[1])..., ones(Int, dims-3)..., length(As))
    return CuSparseArrayCSR(cat(map(A -> A.rowPtr, As)...; dims=bdim),
                            cat(map(A -> A.colVal, As)...; dims=bdim),
                            cat(map(A -> A.nzVal, As)...; dims=bdim),
                            newsize)
end
13+
14+
"""
    cat(As::CuSparseArrayCSR...; dims=3)

Concatenate batched CSR arrays along dimension `dims`.

For `dims ≥ 3`, the `rowPtr`, `colVal` and `nzVal` arrays of the inputs are concatenated
along dimension `dims - 1`. The result takes its first two dimensions from `As[1]` and
its remaining dimensions from the concatenated `rowPtr`.

`dims == 1` forwards to `vcat(As...)` and `dims == 2` to `hcat(As...)`, matching the
meaning of `dims` in `Base.cat`.
"""
function Base.cat(As::CuSparseArrayCSR...; dims=3)
    # Dimension 1 is vertical and dimension 2 is horizontal concatenation; the inputs
    # are splatted so that `vcat`/`hcat` receive the arrays rather than one tuple.
    if dims == 1
        return vcat(As...)
    elseif dims == 2
        return hcat(As...)
    end

    # The storage arrays have one dimension fewer than the logical array.
    bdim = dims - 1
    rowPtr = cat(map(A -> A.rowPtr, As)...; dims=bdim)
    return CuSparseArrayCSR(rowPtr,
                            cat(map(A -> A.colVal, As)...; dims=bdim),
                            cat(map(A -> A.nzVal, As)...; dims=bdim),
                            (size(As[1])[1:2]..., size(rowPtr)[2:end]...))
end
26+
27+
# we can't reshape the first two dimensions
28+
# `reshape(A, :, :, bshape...)`: keep the first two dimensions of `A` and reshape each
# storage array to `(:, bshape...)`. The result is always a `CuSparseArrayCSR`.
function Base.reshape(A::Union{CuSparseArrayCSR, CuSparseMatrixCSR}, ::Colon, ::Colon, bshape::Int64...)
    rowptr = reshape(A.rowPtr, :, bshape...)
    colval = reshape(A.colVal, :, bshape...)
    nzval = reshape(A.nzVal, :, bshape...)
    newdims = (size(A)[1:2]..., bshape...)
    return CuSparseArrayCSR(rowptr, colval, nzval, newdims)
end
34+
35+
# `reshape(A, dims...)`: the first two entries of `dims` must equal the first two
# dimensions of `A` (asserted). The remaining entries become the new shape of the
# trailing dimensions.
function Base.reshape(A::CuSparseArrayCSR, dims::Int64...)
    s1, s2 = dims[1], dims[2]
    bshape = dims[3:end]
    @assert s1 == size(A, 1) && s2 == size(A, 2)
    rowptr = reshape(A.rowPtr, :, bshape...)
    colval = reshape(A.colVal, :, bshape...)
    nzval = reshape(A.nzVal, :, bshape...)
    return CuSparseArrayCSR(rowptr, colval, nzval, (size(A)[1:2]..., bshape...))
end
43+
44+
# reshape to have a single batch dimension
45+
# `reshape(A, :, :, :)`: collapse dimensions 3 and up of `A` into one dimension whose
# length is the product of their sizes.
function Base.reshape(A::CuSparseArrayCSR, ::Colon, ::Colon, ::Colon)
    nbatch = prod(size(A)[3:end])
    rowptr = reshape(A.rowPtr, :, nbatch)
    colval = reshape(A.colVal, :, nbatch)
    nzval = reshape(A.nzVal, :, nbatch)
    return CuSparseArrayCSR(rowptr, colval, nzval, (size(A)[1:2]..., nbatch))
end
52+
53+
# repeat non-matrix dimensions
54+
"""
    repeat(A::Union{CuSparseArrayCSR, CuSparseMatrixCSR}, r1, r2, rs...)

Repeat `A` along its trailing dimensions. `r1` and `r2` must both be `1` (asserted),
because the first two dimensions cannot be repeated. Each storage array is repeated
with counts `(1, rs...)` and the result is a `CuSparseArrayCSR`.

The trailing dimensions of the result are read from the repeated `rowPtr`, so the size
stays consistent with the storage even when `rs` has fewer entries than `A` has trailing
dimensions.
"""
function Base.repeat(A::Union{CuSparseArrayCSR, CuSparseMatrixCSR}, r1::Int64, r2::Int64, rs::Int64...)
    @assert r1 == 1 && r2 == 1 "Cannot repeat matrix dimensions of CuSparseCSR"
    rowPtr = repeat(A.rowPtr, 1, rs...)
    colVal = repeat(A.colVal, 1, rs...)
    nzVal = repeat(A.nzVal, 1, rs...)
    # Dimensions 2 and up of `rowPtr` are the trailing dimensions of the result. Reading
    # them from the storage keeps trailing dimensions of `A` that `rs` does not cover.
    newdims = (size(A)[1:2]..., size(rowPtr)[2:end]...)
    return CuSparseArrayCSR(rowPtr, colVal, nzVal, newdims)
end
61+
62+
# scalar addition/subtraction, scalar mul/div (see interfaces.jl +412)
63+
64+
# chkmmdims (see util.jl)

lib/cusparse/device.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,22 @@ Base.length(g::CuSparseDeviceMatrixCOO) = prod(g.dims)
7272
Base.size(g::CuSparseDeviceMatrixCOO) = g.dims
7373
SparseArrays.nnz(g::CuSparseDeviceMatrixCOO) = g.nnz
7474

75+
"""
    CuSparseDeviceArrayCSR{Tv, Ti, N, M, A} <: AbstractSparseArray{Tv, Ti, N}

Immutable `N`-dimensional sparse array whose storage fields (`rowPtr`, `colVal`, `nzVal`)
are `CuDeviceArray`s with `M` dimensions and parameter `A`, alongside the logical size
`dims` and the stored-entry count `nnz`.
"""
struct CuSparseDeviceArrayCSR{Tv, Ti, N, M, A} <: AbstractSparseArray{Tv, Ti, N}
    rowPtr::CuDeviceArray{Ti, M, A}
    colVal::CuDeviceArray{Ti, M, A}
    nzVal::CuDeviceArray{Tv, M, A}
    dims::NTuple{N, Int}
    nnz::Ti
end
82+
83+
# Outer constructor: asserts that the storage arrays have one dimension fewer than `dims`
# has entries, then forwards to `CuSparseDeviceArrayCSR{Tv, Ti, N, M, A}` with
# `length(nzVal)` as the stored-entry count.
# NOTE(review): the arguments are typed as `CuArray`, while the struct fields are
# `CuDeviceArray` — confirm a conversion exists or whether `CuDeviceArray` was intended.
# NOTE(review): the type is applied to four parameters here but the struct declares five
# (`Tv, Ti, N, M, A`), so the fourth slot (named `A` here) lines up with the struct's `M`
# position — TODO confirm this is intended.
function CuSparseDeviceArrayCSR{Tv, Ti, N, A}(rowPtr::CuArray{<:Integer, M}, colVal::CuArray{<:Integer, M}, nzVal::CuArray{Tv, M}, dims::NTuple{N,<:Integer}) where {Tv, Ti<:Integer, M, N, A}
84+
@assert M == N - 1 "CuSparseDeviceArrayCSR requires ndims(rowPtr) == ndims(colVal) == ndims(nzVal) == length(dims) - 1"
85+
CuSparseDeviceArrayCSR{Tv, Ti, N, M, A}(rowPtr, colVal, nzVal, dims, length(nzVal))
86+
end
87+
88+
# Number of logical elements: the product of all entries of `dims`.
function Base.length(g::CuSparseDeviceArrayCSR)
    return prod(g.dims)
end

# Logical size tuple, as stored in the `dims` field.
function Base.size(g::CuSparseDeviceArrayCSR)
    return g.dims
end

# Stored-entry count, as stored in the `nnz` field.
function SparseArrays.nnz(g::CuSparseDeviceArrayCSR)
    return g.nnz
end
7591

7692
# input/output
7793

@@ -108,3 +124,10 @@ function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceMatrixCOO)
108124
println(io, " colInd: $(A.colInd)")
109125
print(io, " nzVal: $(A.nzVal)")
110126
end
127+
128+
# Four-line display: an element-count header followed by the string form of each
# storage array. The last line has no trailing newline.
function Base.show(io::IO, ::MIME"text/plain", A::CuSparseDeviceArrayCSR)
    header = string(length(A), "-element device sparse array CSR at:")
    println(io, header)
    println(io, string(" rowPtr: ", A.rowPtr))
    println(io, string(" colVal: ", A.colVal))
    print(io, string(" nzVal: ", A.nzVal))
end

lib/cusparse/generic.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
export gather!, scatter!, axpby!, rot!
44
export vv!, sv!, sm!, gemm, gemm!, sddmm!
5+
export bmm!
56

67
## API functions
78

@@ -227,6 +228,89 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::Union{CuS
227228
descB = CuDenseMatrixDescriptor(B)
228229
descC = CuDenseMatrixDescriptor(C)
229230

231+
# cusparseDnMatSetStridedBatch(descB, size(B,3), size(B,1)*size(B,2))
232+
# cusparseDnMatSetStridedBatch(descB, size(B,3), size(B,1)*size(B,2))
233+
# batchsize = length(nonzeros(A)) ÷ nnz(A)
234+
# if batchsize > 1
235+
# cusparseCsrSetStridedBatch(obj, batchsize, 0, nnz(A))
236+
# end
237+
238+
function bufferSize()
239+
out = Ref{Csize_t}()
240+
cusparseSpMM_bufferSize(
241+
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
242+
descC, T, algo, out)
243+
return out[]
244+
end
245+
with_workspace(bufferSize) do buffer
246+
# Uncomment if we find a way to reuse the buffer (issue #1362)
247+
# cusparseSpMM_preprocess(
248+
# handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
249+
# descC, T, algo, buffer)
250+
# end
251+
cusparseSpMM(
252+
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
253+
descC, T, algo, buffer)
254+
end
255+
return C
256+
end
257+
258+
# N-dimensional entry point: reshape `A`, `B` and `C` so that everything after their
# first two dimensions becomes one third dimension, call `bmm!` on the reshaped
# arrays, and return the original `C`.
function bmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseArrayCSR{T,Ti,N},
              B::DenseCuArray{T,N}, beta::Number, C::DenseCuArray{T,N}, index::SparseChar,
              algo::cusparseSpMMAlg_t=CUSPARSE_SPMM_ALG_DEFAULT) where {T,Ti,N}
    flatA = reshape(A, :, :, :)
    flatB = reshape(B, size(B,1), size(B,2), :)
    flatC = reshape(C, size(C,1), size(C,2), :)
    bmm!(transa, transb, alpha, flatA, flatB, beta, flatC, index, algo)
    return C
end
266+
267+
# batched sparse * dense -> dense matmul
268+
function bmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseArrayCSR{T,Ti,3},
269+
B::DenseCuArray{T,3}, beta::Number, C::DenseCuArray{T,3}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_SPMM_ALG_DEFAULT) where {T,Ti}
270+
271+
if CUSPARSE.version() < v"11.7.2"
272+
throw(ErrorException("Batched dense-matrix times batched sparse-matrix (bmm!) requires a CUSPARSE version ≥ 11.7.2 (yours: $(CUSPARSE.version()))."))
273+
end
274+
275+
276+
# Support `transa = 'C'` and `transb = 'C'` for real matrices
277+
transa = T <: Real && transa == 'C' ? 'T' : transa
278+
transb = T <: Real && transb == 'C' ? 'T' : transb
279+
280+
m, k = size(A)[1:2]
281+
n, bc = size(C)[2:3]
282+
b = max(size(A, 3), size(B, 3))
283+
284+
if b != bc
285+
throw(ArgumentError("C must have same batch-dimension as max(size(A,3)=$(size(A,3)), size(B,3)=$(size(B,3))), got $(size(C,3))."))
286+
end
287+
288+
if n == 1 && b > 1
289+
throw(ArgumentError("bmm! does not work for n==1 and b>1 due to CUDA error."))
290+
end
291+
292+
if transa == 'N' && transb == 'N'
293+
chkbmmdims(B,C,k,n,m,n)
294+
elseif transa == 'N' && transb != 'N'
295+
chkbmmdims(B,C,n,k,m,n)
296+
elseif transa != 'N' && transb == 'N'
297+
chkbmmdims(B,C,m,n,k,n)
298+
elseif transa != 'N' && transb != 'N'
299+
chkbmmdims(B,C,n,m,k,n)
300+
end
301+
302+
descA = CuSparseMatrixDescriptor(A, index)
303+
descB = CuDenseMatrixDescriptor(B)
304+
descC = CuDenseMatrixDescriptor(C)
305+
306+
cusparseCsrSetStridedBatch(descA, b, ptrstride(A), valstride(A))
307+
308+
strideB = stride(B, 3)
309+
cusparseDnMatSetStridedBatch(descB, b, strideB)
310+
311+
strideC = stride(C, 3)
312+
cusparseDnMatSetStridedBatch(descC, b, strideC)
313+
230314
function bufferSize()
231315
out = Ref{Csize_t}()
232316
cusparseSpMM_bufferSize(

0 commit comments

Comments
 (0)