Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit c8d9a9b

Browse files
authored
Merge pull request #572 from janEbert/sparse
Implement a few sparse array basics
2 parents e6b5376 + 8acbd05 commit c8d9a9b

File tree

2 files changed

+105
-2
lines changed

2 files changed

+105
-2
lines changed

src/sparse/array.jl

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ export CuSparseMatrixCSC, CuSparseMatrixCSR,
77
CuSparseVector
88

99
import Base: length, size, ndims, eltype, similar, pointer, stride,
10-
copy, convert, reinterpret, show, summary, copyto!, get!, fill!, collect
10+
copy, convert, reinterpret, show, summary, copyto!, getindex, get!, fill!, collect
1111

1212
using LinearAlgebra
1313
import LinearAlgebra: BlasFloat, Hermitian, HermOrSym, issymmetric, Transpose, Adjoint,
1414
ishermitian, istriu, istril, Symmetric, UpperTriangular, LowerTriangular
1515

1616
using SparseArrays
17-
import SparseArrays: sparse, SparseMatrixCSC
17+
import SparseArrays: sparse, SparseMatrixCSC, nnz, nonzeros, nonzeroinds,
18+
_spgetindex
1819

1920
abstract type AbstractCuSparseArray{Tv, N} <: AbstractSparseArray{Tv, Cint, N} end
2021
const AbstractCuSparseVector{Tv} = AbstractCuSparseArray{Tv,1}
@@ -166,6 +167,11 @@ function size(g::CuSparseMatrix, d::Integer)
166167
end
167168
end
168169

170+
nnz(g::AbstractCuSparseArray) = g.nnz
171+
nonzeros(g::AbstractCuSparseArray) = g.nzVal
172+
173+
nonzeroinds(g::AbstractCuSparseVector) = g.iPtr
174+
169175
issymmetric(M::Union{CuSparseMatrixCSC,CuSparseMatrixCSR}) = false
170176
ishermitian(M::Union{CuSparseMatrixCSC,CuSparseMatrixCSR}) = false
171177
issymmetric(M::Symmetric{CuSparseMatrixCSC}) = true
@@ -177,6 +183,67 @@ istriu(M::LowerTriangular{T,S}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix}
177183
istril(M::LowerTriangular{T,S}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix} = true
178184
eltype(g::CuSparseMatrix{T}) where T = T
179185

186+
# getindex (mostly adapted from stdlib/SparseArrays)
187+
188+
# Translations
189+
getindex(A::AbstractCuSparseVector, ::Colon) = copy(A)
190+
getindex(A::AbstractCuSparseMatrix, ::Colon, ::Colon) = copy(A)
191+
getindex(A::AbstractCuSparseMatrix, i, ::Colon) = getindex(A, i, 1:size(A, 2))
192+
getindex(A::AbstractCuSparseMatrix, ::Colon, i) = getindex(A, 1:size(A, 1), i)
193+
getindex(A::AbstractCuSparseMatrix, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2])
194+
195+
# Column slices
196+
function getindex(x::CuSparseMatrixCSC, ::Colon, j::Integer)
197+
checkbounds(x, :, j)
198+
r1 = convert(Int, x.colPtr[j])
199+
r2 = convert(Int, x.colPtr[j+1]) - 1
200+
CuSparseVector(x.rowVal[r1:r2], x.nzVal[r1:r2], size(x, 1))
201+
end
202+
203+
function getindex(x::CuSparseMatrixCSR, i::Integer, ::Colon)
204+
checkbounds(x, :, i)
205+
c1 = convert(Int, x.rowPtr[i])
206+
c2 = convert(Int, x.rowPtr[i+1]) - 1
207+
CuSparseVector(x.colVal[c1:c2], x.nzVal[c1:c2], size(x, 2))
208+
end
209+
210+
# Row slices
211+
# TODO optimize
212+
getindex(A::CuSparseMatrixCSC, i::Integer, ::Colon) = CuSparseVector(sparse(A[i, 1:end]))
213+
# TODO optimize
214+
getindex(A::CuSparseMatrixCSR, ::Colon, j::Integer) = CuSparseVector(sparse(A[1:end, j]))
215+
216+
function getindex(A::CuSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T
217+
m, n = size(A)
218+
if !(1 <= i0 <= m && 1 <= i1 <= n)
219+
throw(BoundsError())
220+
end
221+
r1 = Int(A.colPtr[i1])
222+
r2 = Int(A.colPtr[i1+1]-1)
223+
(r1 > r2) && return zero(T)
224+
r1 = searchsortedfirst(A.rowVal, i0, r1, r2, Base.Order.Forward)
225+
((r1 > r2) || (A.rowVal[r1] != i0)) ? zero(T) : A.nzVal[r1]
226+
end
227+
228+
function getindex(A::CuSparseMatrixCSR{T}, i0::Integer, i1::Integer) where T
229+
m, n = size(A)
230+
if !(1 <= i0 <= m && 1 <= i1 <= n)
231+
throw(BoundsError())
232+
end
233+
c1 = Int(A.rowPtr[i0])
234+
c2 = Int(A.rowPtr[i0+1]-1)
235+
(c1 > c2) && return zero(T)
236+
c1 = searchsortedfirst(A.colVal, i1, c1, c2, Base.Order.Forward)
237+
((c1 > c2) || (A.colVal[c1] != i1)) ? zero(T) : A.nzVal[c1]
238+
end
239+
240+
# Called for indexing into `CuSparseVector`s
241+
function _spgetindex(m::Integer, nzind::CuVector{Ti}, nzval::CuVector{Tv},
242+
i::Integer) where {Tv,Ti}
243+
ii = searchsortedfirst(nzind, convert(Ti, i))
244+
(ii <= m && nzind[ii] == i) ? nzval[ii] : zero(Tv)
245+
end
246+
180247
function collect(Vec::CuSparseVector)
181248
SparseVector(Vec.dims[1], collect(Vec.iPtr), collect(Vec.nzVal))
182249
end

test/sparse.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,19 @@ blockdim = 5
1818
@test size(d_x,1) == m
1919
@test size(d_x,2) == 1
2020
@test ndims(d_x) == 1
21+
CuArrays.@allowscalar begin
22+
@test Array(d_x[:]) == x[:]
23+
@test d_x[firstindex(d_x)] == x[firstindex(x)]
24+
@test d_x[div(end, 2)] == x[div(end, 2)]
25+
@test d_x[end] == x[end]
26+
@test Array(d_x[firstindex(d_x):end]) == x[firstindex(x):end]
27+
end
28+
@test_throws BoundsError d_x[firstindex(d_x) - 1]
29+
@test_throws BoundsError d_x[end + 1]
30+
@test nnz(d_x) == nnz(x)
31+
@test Array(nonzeros(d_x)) == nonzeros(x)
32+
@test Array(SparseArrays.nonzeroinds(d_x)) == SparseArrays.nonzeroinds(x)
33+
@test nnz(d_x) == length(nonzeros(d_x))
2134
x = sprand(m,n,0.2)
2235
d_x = CuSparseMatrixCSC(x)
2336
@test length(d_x) == m*n
@@ -26,6 +39,29 @@ blockdim = 5
2639
@test size(d_x,2) == n
2740
@test size(d_x,3) == 1
2841
@test ndims(d_x) == 2
42+
CuArrays.@allowscalar begin
43+
@test Array(d_x[:]) == x[:]
44+
@test d_x[firstindex(d_x)] == x[firstindex(x)]
45+
@test d_x[div(end, 2)] == x[div(end, 2)]
46+
@test d_x[end] == x[end]
47+
@test d_x[firstindex(d_x), firstindex(d_x)] == x[firstindex(x), firstindex(x)]
48+
@test d_x[div(end, 2), div(end, 2)] == x[div(end, 2), div(end, 2)]
49+
@test d_x[end, end] == x[end, end]
50+
@test Array(d_x[firstindex(d_x):end, firstindex(d_x):end]) == x[:, :]
51+
end
52+
@test_throws BoundsError d_x[firstindex(d_x) - 1]
53+
@test_throws BoundsError d_x[end + 1]
54+
@test_throws BoundsError d_x[firstindex(d_x) - 1, firstindex(d_x) - 1]
55+
@test_throws BoundsError d_x[end + 1, end + 1]
56+
@test_throws BoundsError d_x[firstindex(d_x) - 1:end + 1, :]
57+
@test_throws BoundsError d_x[firstindex(d_x) - 1, :]
58+
@test_throws BoundsError d_x[end + 1, :]
59+
@test_throws BoundsError d_x[:, firstindex(d_x) - 1:end + 1]
60+
@test_throws BoundsError d_x[:, firstindex(d_x) - 1]
61+
@test_throws BoundsError d_x[:, end + 1]
62+
@test nnz(d_x) == nnz(x)
63+
@test Array(nonzeros(d_x)) == nonzeros(x)
64+
@test nnz(d_x) == length(nonzeros(d_x))
2965
@test !issymmetric(d_x)
3066
@test !ishermitian(d_x)
3167
@test_throws ArgumentError size(d_x,0)

0 commit comments

Comments
 (0)