Skip to content

Commit 8a506f4

Browse files
mcognettadkarrasch
authored andcommitted
Avoid allocating diagonal zeros in sparse(::Diagonal) (#42577)
Co-authored-by: Daniel Karrasch <daniel.karrasch@posteo.de>
1 parent 71c4acc commit 8a506f4

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

src/sparsematrix.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -656,8 +656,33 @@ end
656656
SparseMatrixCSC(D::Diagonal{Tv}) where Tv = SparseMatrixCSC{Tv,Int}(D)
657657
function SparseMatrixCSC{Tv,Ti}(D::Diagonal) where {Tv,Ti}
658658
m = length(D.diag)
659-
return SparseMatrixCSC(m, m, Vector(Ti(1):Ti(m+1)), Vector(Ti(1):Ti(m)), Vector{Tv}(D.diag))
659+
m == 0 && return SparseMatrixCSC{Tv,Ti}(zeros(Tv, 0, 0))
660+
661+
nz = count(!iszero, D.diag)
662+
nz_counter = 1
663+
664+
rowval = Vector{Ti}(undef, nz)
665+
nzval = Vector{Tv}(undef, nz)
666+
667+
nz == 0 && return SparseMatrixCSC{Tv,Ti}(m, m, ones(Ti, m+1), rowval, nzval)
668+
669+
colptr = Vector{Ti}(undef, m+1)
670+
671+
@inbounds for i=1:m
672+
if !iszero(D.diag[i])
673+
colptr[i] = nz_counter
674+
rowval[nz_counter] = i
675+
nzval[nz_counter] = D.diag[i]
676+
nz_counter += 1
677+
else
678+
colptr[i] = nz_counter
679+
end
680+
end
681+
colptr[end] = nz_counter
682+
683+
return SparseMatrixCSC{Tv,Ti}(m, m, colptr, rowval, nzval)
660684
end
685+
661686
SparseMatrixCSC(M::AbstractMatrix{Tv}) where {Tv} = SparseMatrixCSC{Tv,Int}(M)
662687
SparseMatrixCSC{Tv}(M::AbstractMatrix{Tv}) where {Tv} = SparseMatrixCSC{Tv,Int}(M)
663688
function SparseMatrixCSC{Tv,Ti}(M::AbstractMatrix) where {Tv,Ti}

test/sparse.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,6 +1605,14 @@ end
16051605
@test sparse(s) == s
16061606
end
16071607

1608+
@testset "avoid allocation for zeros in diagonal" begin
1609+
x = [1, 0, 0, 5, 0]
1610+
d = Diagonal(x)
1611+
s = sparse(d)
1612+
@test s == d
1613+
@test nnz(s) == 2
1614+
end
1615+
16081616
@testset "error conditions for reshape, and dropdims" begin
16091617
local A = sprand(Bool, 5, 5, 0.2)
16101618
@test_throws DimensionMismatch reshape(A,(20, 2))

0 commit comments

Comments
 (0)