Skip to content

Commit 285c7a5

Browse files
authored
Merge pull request #1409 from JuliaGPU/tb/sparse_bsr
CUSPARSE BSR improvements
2 parents a50e85c + 799faa9 commit 285c7a5

File tree

5 files changed

+49
-33
lines changed

5 files changed

+49
-33
lines changed

lib/cusparse/array.jl

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ mutable struct CuSparseMatrixBSR{Tv, Ti} <: AbstractCuSparseMatrix{Tv, Ti}
102102
dims::NTuple{2,Int}
103103
blockDim::Ti
104104
dir::SparseChar
105-
nnz::Ti
105+
nnzb::Ti
106106

107107
function CuSparseMatrixBSR{Tv, Ti}(rowPtr::CuVector{<:Integer}, colVal::CuVector{<:Integer},
108108
nzVal::CuVector, dims::NTuple{2,<:Integer},
@@ -268,6 +268,8 @@ LinearAlgebra.istril(M::LowerTriangular{T,S}) where {T<:BlasFloat, S<:AbstractCu
268268

269269
Hermitian{T}(Mat::CuSparseMatrix{T}) where T = Hermitian{T,typeof(Mat)}(Mat,'U')
270270

271+
SparseArrays.nnz(g::CuSparseMatrixBSR) = g.nnzb * g.blockDim * g.blockDim
272+
271273

272274
## indexing
273275

@@ -297,6 +299,13 @@ end
297299
Base.getindex(A::CuSparseMatrixCSC, i::Integer, ::Colon) = CuSparseVector(sparse(A[i, 1:end])) # TODO: optimize
298300
Base.getindex(A::CuSparseMatrixCSR, ::Colon, j::Integer) = CuSparseVector(sparse(A[1:end, j])) # TODO: optimize
299301

302+
function Base.getindex(A::CuSparseVector{Tv, Ti}, i::Integer) where {Tv, Ti}
303+
@boundscheck checkbounds(A, i)
304+
ii = searchsortedfirst(A.iPtr, convert(Ti, i))
305+
(ii > nnz(A) || A.iPtr[ii] != i) && return zero(Tv)
306+
A.nzVal[ii]
307+
end
308+
300309
function Base.getindex(A::CuSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T
301310
@boundscheck checkbounds(A, i0, i1)
302311
r1 = Int(A.colPtr[i1])
@@ -327,10 +336,17 @@ function Base.getindex(A::CuSparseMatrixCOO{T}, i0::Integer, i1::Integer) where
327336
nonzeros(A)[c1]
328337
end
329338

330-
function SparseArrays._spgetindex(m::Integer, nzind::CuVector{Ti}, nzval::CuVector{Tv},
331-
i::Integer) where {Tv,Ti}
332-
ii = searchsortedfirst(nzind, convert(Ti, i))
333-
(ii <= m && nzind[ii] == i) ? nzval[ii] : zero(Tv)
339+
function Base.getindex(A::CuSparseMatrixBSR{T}, i0::Integer, i1::Integer) where T
340+
@boundscheck checkbounds(A, i0, i1)
341+
i0_block, i0_idx = fldmod1(i0, A.blockDim)
342+
i1_block, i1_idx = fldmod1(i1, A.blockDim)
343+
block_idx = (i0_idx - 1) * A.blockDim + i1_idx - 1
344+
c1 = Int(A.rowPtr[i0_block])
345+
c2 = Int(A.rowPtr[i0_block+1]-1)
346+
(c1 > c2) && return zero(T)
347+
c1 = searchsortedfirst(A.colVal, i1_block, c1, c2, Base.Order.Forward)
348+
(c1 > c2 || A.colVal[c1] != i1_block) && return zero(T)
349+
nonzeros(A)[c1+block_idx]
334350
end
335351

336352

@@ -407,7 +423,7 @@ function Base.copyto!(dst::CuSparseVector, src::CuSparseVector)
407423
end
408424
copyto!(nonzeroinds(dst), nonzeroinds(src))
409425
copyto!(nonzeros(dst), nonzeros(src))
410-
dst.nnz = nnz(src)
426+
dst.nnz = src.nnz
411427
dst
412428
end
413429

@@ -418,7 +434,7 @@ function Base.copyto!(dst::CuSparseMatrixCSC, src::CuSparseMatrixCSC)
418434
copyto!(dst.colPtr, src.colPtr)
419435
copyto!(rowvals(dst), rowvals(src))
420436
copyto!(nonzeros(dst), nonzeros(src))
421-
dst.nnz = nnz(src)
437+
dst.nnz = src.nnz
422438
dst
423439
end
424440

@@ -429,7 +445,7 @@ function Base.copyto!(dst::CuSparseMatrixCSR, src::CuSparseMatrixCSR)
429445
copyto!(dst.rowPtr, src.rowPtr)
430446
copyto!(dst.colVal, src.colVal)
431447
copyto!(nonzeros(dst), nonzeros(src))
432-
dst.nnz = nnz(src)
448+
dst.nnz = src.nnz
433449
dst
434450
end
435451

@@ -441,7 +457,7 @@ function Base.copyto!(dst::CuSparseMatrixBSR, src::CuSparseMatrixBSR)
441457
copyto!(dst.colVal, src.colVal)
442458
copyto!(nonzeros(dst), nonzeros(src))
443459
dst.dir = src.dir
444-
dst.nnz = nnz(src)
460+
dst.nnzb = src.nnzb
445461
dst
446462
end
447463

@@ -452,7 +468,7 @@ function Base.copyto!(dst::CuSparseMatrixCOO, src::CuSparseMatrixCOO)
452468
copyto!(dst.rowInd, src.rowInd)
453469
copyto!(dst.colInd, src.colInd)
454470
copyto!(nonzeros(dst), nonzeros(src))
455-
dst.nnz = nnz(src)
471+
dst.nnz = src.nnz
456472
dst
457473
end
458474

@@ -537,7 +553,7 @@ function Adapt.adapt_structure(to::CUDA.Adaptor, x::CuSparseMatrixBSR)
537553
adapt(to, x.colVal),
538554
adapt(to, x.nzVal),
539555
size(x), x.blockDim,
540-
x.dir, x.nnz
556+
x.dir, x.nnzb
541557
)
542558
end
543559

lib/cusparse/conversions.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ for (fname,elty) in ((:cusparseScsr2bsr, :Float32),
216216
indc::SparseChar='O')
217217
m,n = size(csr)
218218
nnz_ref = Ref{Cint}(1)
219-
mb = div((m + blockDim - 1),blockDim)
220-
nb = div((n + blockDim - 1),blockDim)
219+
mb = cld(m, blockDim)
220+
nb = cld(n, blockDim)
221221
bsrRowPtr = CUDA.zeros(Cint,mb + 1)
222222
cudesca = CuMatrixDescriptor('G', 'L', 'N', inda)
223223
cudescc = CuMatrixDescriptor('G', 'L', 'N', indc)
@@ -242,19 +242,20 @@ for (fname,elty) in ((:cusparseSbsr2csr, :Float32),
242242
function CuSparseMatrixCSR{$elty}(bsr::CuSparseMatrixBSR{$elty};
243243
inda::SparseChar='O', indc::SparseChar='O')
244244
m,n = size(bsr)
245-
mb = div(m,bsr.blockDim)
246-
nb = div(n,bsr.blockDim)
247-
nnzVal = nnz(bsr) * bsr.blockDim * bsr.blockDim
245+
mb = cld(m, bsr.blockDim)
246+
nb = cld(n, bsr.blockDim)
248247
cudesca = CuMatrixDescriptor('G', 'L', 'N', inda)
249248
cudescc = CuMatrixDescriptor('G', 'L', 'N', indc)
250249
csrRowPtr = CUDA.zeros(Cint, m + 1)
251-
csrColInd = CUDA.zeros(Cint, nnzVal)
252-
csrNzVal = CUDA.zeros($elty, nnzVal)
250+
csrColInd = CUDA.zeros(Cint, nnz(bsr))
251+
csrNzVal = CUDA.zeros($elty, nnz(bsr))
253252
$fname(handle(), bsr.dir, mb, nb,
254253
cudesca, nonzeros(bsr), bsr.rowPtr, bsr.colVal,
255254
bsr.blockDim, cudescc, csrNzVal, csrRowPtr,
256255
csrColInd)
257-
CuSparseMatrixCSR(csrRowPtr, csrColInd, csrNzVal, size(bsr))
256+
# XXX: the size here may not match the expected size, when the matrix dimension
257+
# is not a multiple of the block dimension!
258+
CuSparseMatrixCSR(csrRowPtr, csrColInd, csrNzVal, (mb*bsr.blockDim, nb*bsr.blockDim))
258259
end
259260
end
260261
end

lib/cusparse/level2.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ for (bname,aname,sname,elty) in ((:cusparseSbsrsv2_bufferSize, :cusparseSbsrsv2_
7070
if m != n
7171
throw(DimensionMismatch("A must be square, but has dimensions ($m,$n)!"))
7272
end
73-
mb = div(m,A.blockDim)
73+
mb = cld(m, A.blockDim)
7474
mX = length(X)
7575
if mX != m
7676
throw(DimensionMismatch("X must have length $m, but has length $mX"))
@@ -80,21 +80,21 @@ for (bname,aname,sname,elty) in ((:cusparseSbsrsv2_bufferSize, :cusparseSbsrsv2_
8080

8181
function bufferSize()
8282
out = Ref{Cint}(1)
83-
$bname(handle(), A.dir, transa, mb, nnz(A),
83+
$bname(handle(), A.dir, transa, mb, A.nnzb,
8484
desc, nonzeros(A), A.rowPtr, A.colVal, A.blockDim,
8585
info[1], out)
8686
return out[]
8787
end
8888
with_workspace(bufferSize) do buffer
89-
$aname(handle(), A.dir, transa, mb, nnz(A),
89+
$aname(handle(), A.dir, transa, mb, A.nnzb,
9090
desc, nonzeros(A), A.rowPtr, A.colVal, A.blockDim,
9191
info[1], CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
9292
posit = Ref{Cint}(1)
9393
cusparseXbsrsv2_zeroPivot(handle(), info[1], posit)
9494
if posit[] >= 0
9595
error("Structural/numerical zero in A at ($(posit[]),$(posit[])))")
9696
end
97-
$sname(handle(), A.dir, transa, mb, nnz(A),
97+
$sname(handle(), A.dir, transa, mb, A.nnzb,
9898
alpha, desc, nonzeros(A), A.rowPtr, A.colVal,
9999
A.blockDim, info[1], X, X,
100100
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)

lib/cusparse/level3.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ for (fname,elty) in ((:cusparseSbsrmm, :Float32),
2828
index::SparseChar)
2929
desc = CuMatrixDescriptor('G', 'L', 'N', index)
3030
m,k = size(A)
31-
mb = div(m,A.blockDim)
32-
kb = div(k,A.blockDim)
31+
mb = cld(m, A.blockDim)
32+
kb = cld(k, A.blockDim)
3333
n = size(C)[2]
3434
if transa == 'N' && transb == 'N'
3535
chkmmdims(B,C,k,n,m,n)
@@ -43,7 +43,7 @@ for (fname,elty) in ((:cusparseSbsrmm, :Float32),
4343
ldb = max(1,stride(B,2))
4444
ldc = max(1,stride(C,2))
4545
$fname(handle(), A.dir,
46-
transa, transb, mb, n, kb, nnz(A),
46+
transa, transb, mb, n, kb, A.nnzb,
4747
alpha, desc, nonzeros(A),A.rowPtr, A.colVal,
4848
A.blockDim, B, ldb, beta, C, ldc)
4949
C
@@ -156,7 +156,7 @@ for (bname,aname,sname,elty) in ((:cusparseSbsrsm2_bufferSize, :cusparseSbsrsm2_
156156
if m != n
157157
throw(DimensionMismatch("A must be square, but has dimensions ($m,$n)!"))
158158
end
159-
mb = div(m,A.blockDim)
159+
mb = cld(m, A.blockDim)
160160
mX,nX = size(X)
161161
if transxy == 'N' && (mX != m)
162162
throw(DimensionMismatch(""))
@@ -171,14 +171,14 @@ for (bname,aname,sname,elty) in ((:cusparseSbsrsm2_bufferSize, :cusparseSbsrsm2_
171171
function bufferSize()
172172
out = Ref{Cint}(1)
173173
$bname(handle(), A.dir, transa, transxy,
174-
mb, nX, nnz(A), desc, nonzeros(A), A.rowPtr,
174+
mb, nX, A.nnzb, desc, nonzeros(A), A.rowPtr,
175175
A.colVal, A.blockDim, info[],
176176
out)
177177
return out[]
178178
end
179179
with_workspace(bufferSize) do buffer
180180
$aname(handle(), A.dir, transa, transxy,
181-
mb, nX, nnz(A), desc, nonzeros(A), A.rowPtr,
181+
mb, nX, A.nnzb, desc, nonzeros(A), A.rowPtr,
182182
A.colVal, A.blockDim, info[],
183183
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
184184
posit = Ref{Cint}(1)
@@ -187,7 +187,7 @@ for (bname,aname,sname,elty) in ((:cusparseSbsrsm2_bufferSize, :cusparseSbsrsm2_
187187
error("Structural/numerical zero in A at ($(posit[]),$(posit[])))")
188188
end
189189
$sname(handle(), A.dir, transa, transxy, mb,
190-
nX, nnz(A), alpha, desc, nonzeros(A), A.rowPtr,
190+
nX, A.nnzb, alpha, desc, nonzeros(A), A.rowPtr,
191191
A.colVal, A.blockDim, info[], X, ldx, X, ldx,
192192
CUSPARSE_SOLVE_POLICY_USE_LEVEL, buffer)
193193
end

test/cusparse/device.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ using CUDA.CUSPARSE: CuSparseDeviceVector, CuSparseDeviceMatrixCSC, CuSparseDevi
2424
cuA = CuSparseMatrixCOO(A)
2525
@test cudaconvert(cuA) isa CuSparseDeviceMatrixCOO{Float64, Cint, 1}
2626

27-
# Roger-Luo: I'm not sure how to create a BSR matrix
28-
# cuA = CuSparseMatrixBSR(A)
29-
# @test cudaconvert(cuA) isa CuSparseDeviceMatrixBSR
27+
cuA = CuSparseMatrixBSR(A, 2)
28+
@test cudaconvert(cuA) isa CuSparseDeviceMatrixBSR
3029
end

0 commit comments

Comments
 (0)