Skip to content

Commit 38fe81f

Browse files
Added few SparseArrays functions (#2254)
1 parent fc1d509 commit 38fe81f

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

lib/cusparse/array.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,26 @@ SparseArrays.nonzeros(g::AbstractCuSparseArray) = g.nzVal
301301
SparseArrays.nonzeroinds(g::AbstractCuSparseVector) = g.iPtr
302302

303303
SparseArrays.rowvals(g::CuSparseMatrixCSC) = g.rowVal
304+
SparseArrays.getcolptr(S::CuSparseMatrixCSC) = S.colPtr
305+
306+
function SparseArrays.findnz(S::MT) where {MT <: AbstractCuSparseMatrix}
307+
S2 = CuSparseMatrixCOO(S)
308+
I = S2.rowInd
309+
J = S2.colInd
310+
V = S2.nzVal
311+
312+
# To make it compatible with the SparseArrays.jl version
313+
idxs = sortperm(J)
314+
I = I[idxs]
315+
J = J[idxs]
316+
V = V[idxs]
317+
318+
return (I, J, V)
319+
end
320+
321+
function SparseArrays.sparsevec(I::CuArray{Ti}, V::CuArray{Tv}, n::Integer) where {Ti,Tv}
322+
CuSparseVector(I, V, n)
323+
end
304324

305325
LinearAlgebra.issymmetric(M::Union{CuSparseMatrixCSC,CuSparseMatrixCSR}) = size(M, 1) == size(M, 2) ? norm(M - transpose(M), Inf) == 0 : false
306326
LinearAlgebra.ishermitian(M::Union{CuSparseMatrixCSC,CuSparseMatrixCSR}) = size(M, 1) == size(M, 2) ? norm(M - adjoint(M), Inf) == 0 : false

test/libraries/cusparse.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using CUDA.CUSPARSE
22

33
using LinearAlgebra
44
using SparseArrays
5-
using SparseArrays: nonzeroinds
5+
using SparseArrays: nonzeroinds, getcolptr
66

77
@test CUSPARSE.version() isa VersionNumber
88

@@ -106,6 +106,22 @@ blockdim = 5
106106
d_x = LowerTriangular(CuSparseMatrixCSC(x))
107107
@test !istriu(d_x)
108108
@test istril(d_x)
109+
110+
A = sprand(n, n, 0.2)
111+
d_A = CuSparseMatrixCSC(A)
112+
@test Array(getcolptr(d_A)) == getcolptr(A)
113+
i, j, v = findnz(A)
114+
d_i, d_j, d_v = findnz(d_A)
115+
@test Array(d_i) == i && Array(d_j) == j && Array(d_v) == v
116+
i = unique(sort(rand(1:n, 10)))
117+
vals = rand(length(i))
118+
d_i = CuArray(i)
119+
d_vals = CuArray(vals)
120+
v = sparsevec(i, vals, n)
121+
d_v = sparsevec(d_i, d_vals, n)
122+
@test Array(d_v.iPtr) == v.nzind
123+
@test Array(d_v.nzVal) == v.nzval
124+
@test d_v.len == v.n
109125
end
110126

111127
@testset "construction" begin

0 commit comments

Comments
 (0)