Skip to content

Commit 26b7748

Browse files
devmotionDavid Widmann
andauthored
Move SparseArrays support to an extension (#638)
Co-authored-by: David Widmann <devmotion@noreply.users.github.com>
1 parent 987b83a commit 26b7748

File tree

5 files changed

+117
-104
lines changed

5 files changed

+117
-104
lines changed

Project.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.17.0"
3+
version = "1.18.0"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
99

10+
[weakdeps]
11+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
12+
13+
[extensions]
14+
ChainRulesCoreSparseArraysExt = "SparseArrays"
15+
1016
[compat]
1117
BenchmarkTools = "0.5"
1218
Compat = "2, 3, 4"
@@ -19,8 +25,9 @@ julia = "1.6"
1925
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
2026
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2127
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
28+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2229
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2330
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2431

2532
[targets]
26-
test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "StaticArrays"]
33+
test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "SparseArrays", "StaticArrays"]

ext/ChainRulesCoreSparseArraysExt.jl

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
module ChainRulesCoreSparseArraysExt
2+
3+
using ChainRulesCore
4+
using ChainRulesCore: project_type, _projection_mismatch
5+
using SparseArrays: SparseVector, SparseMatrixCSC, nzrange, rowvals
6+
7+
ChainRulesCore.is_inplaceable_destination(::SparseVector) = true
8+
ChainRulesCore.is_inplaceable_destination(::SparseMatrixCSC) = true
9+
10+
# Word from on high is that we should regard all un-stored values of sparse arrays as
11+
# structural zeros. Thus ProjectTo needs to store nzind, and get only those.
12+
# This implementation very naiive, can probably be made more efficient.
13+
14+
function ChainRulesCore.ProjectTo(x::SparseVector{T}) where {T<:Number}
15+
return ProjectTo{SparseVector}(;
16+
element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x)
17+
)
18+
end
19+
function (project::ProjectTo{SparseVector})(dx::AbstractArray)
20+
dy = if axes(dx) == project.axes
21+
dx
22+
else
23+
if size(dx, 1) != length(project.axes[1])
24+
throw(_projection_mismatch(project.axes, size(dx)))
25+
end
26+
reshape(dx, project.axes)
27+
end
28+
nzval = map(i -> project.element(dy[i]), project.nzind)
29+
return SparseVector(length(dx), project.nzind, nzval)
30+
end
31+
function (project::ProjectTo{SparseVector})(dx::SparseVector)
32+
if size(dx) != map(length, project.axes)
33+
throw(_projection_mismatch(project.axes, size(dx)))
34+
end
35+
# When sparsity pattern is unchanged, all the time is in checking this,
36+
# perhaps some simple hash/checksum might be good enough?
37+
samepattern = project.nzind == dx.nzind
38+
# samepattern = length(project.nzind) == length(dx.nzind)
39+
if eltype(dx) <: project_type(project.element) && samepattern
40+
return dx
41+
elseif samepattern
42+
nzval = map(project.element, dx.nzval)
43+
SparseVector(length(dx), dx.nzind, nzval)
44+
else
45+
nzind = project.nzind
46+
# Or should we intersect? Can this exploit sorting?
47+
# nzind = intersect(project.nzind, dx.nzind)
48+
nzval = map(i -> project.element(dx[i]), nzind)
49+
return SparseVector(length(dx), nzind, nzval)
50+
end
51+
end
52+
53+
function ChainRulesCore.ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number}
54+
return ProjectTo{SparseMatrixCSC}(;
55+
element=ProjectTo(zero(T)),
56+
axes=axes(x),
57+
rowval=rowvals(x),
58+
nzranges=nzrange.(Ref(x), axes(x, 2)),
59+
colptr=x.colptr,
60+
)
61+
end
62+
# You need not really store nzranges, you can get them from colptr -- TODO
63+
# nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1)
64+
function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray)
65+
dy = if axes(dx) == project.axes
66+
dx
67+
else
68+
if size(dx) != (length(project.axes[1]), length(project.axes[2]))
69+
throw(_projection_mismatch(project.axes, size(dx)))
70+
end
71+
reshape(dx, project.axes)
72+
end
73+
nzval = Vector{project_type(project.element)}(undef, length(project.rowval))
74+
k = 0
75+
for col in project.axes[2]
76+
for i in project.nzranges[col]
77+
row = project.rowval[i]
78+
val = dy[row, col]
79+
nzval[k += 1] = project.element(val)
80+
end
81+
end
82+
m, n = map(length, project.axes)
83+
return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval)
84+
end
85+
86+
function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
87+
if size(dx) != map(length, project.axes)
88+
throw(_projection_mismatch(project.axes, size(dx)))
89+
end
90+
samepattern = dx.colptr == project.colptr && dx.rowval == project.rowval
91+
# samepattern = length(dx.colptr) == length(project.colptr) && dx.colptr[end] == project.colptr[end]
92+
if eltype(dx) <: project_type(project.element) && samepattern
93+
return dx
94+
elseif samepattern
95+
nzval = map(project.element, dx.nzval)
96+
m, n = size(dx)
97+
return SparseMatrixCSC(m, n, dx.colptr, dx.rowval, nzval)
98+
else
99+
invoke(project, Tuple{AbstractArray}, dx)
100+
end
101+
end
102+
103+
end # module

src/ChainRulesCore.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module ChainRulesCore
22
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
33
using Base.Meta
44
using LinearAlgebra
5-
using SparseArrays: SparseVector, SparseMatrixCSC
65
using Compat: hasfield, hasproperty
76

87
export frule, rrule # core function
@@ -36,4 +35,9 @@ include("ignore_derivatives.jl")
3635

3736
include("deprecated.jl")
3837

38+
# SparseArrays support on Julia < 1.9
39+
if !isdefined(Base, :get_extension)
40+
include("../ext/ChainRulesCoreSparseArraysExt.jl")
41+
end
42+
3943
end # module

src/accumulation.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,6 @@ is_inplaceable_destination(::Any) = false
5656
is_inplaceable_destination(::Array) = true
5757
is_inplaceable_destination(:: Array{<:Integer}) = false
5858

59-
is_inplaceable_destination(::SparseVector) = true
60-
is_inplaceable_destination(::SparseMatrixCSC) = true
61-
6259
function is_inplaceable_destination(x::SubArray)
6360
alpha = is_inplaceable_destination(parent(x))
6461
beta = x.indices isa Tuple{Vararg{Union{Integer, Base.Slice, UnitRange}}}

src/projection.jl

Lines changed: 0 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -515,101 +515,3 @@ function (project::ProjectTo{Tridiagonal})(dx::AbstractArray)
515515
return Tridiagonal(dy)
516516
end
517517
# Note that backing(::Tridiagonal) doesn't work, https://github.com/JuliaDiff/ChainRulesCore.jl/issues/392
518-
519-
#####
520-
##### `SparseArrays`
521-
#####
522-
523-
using SparseArrays
524-
# Word from on high is that we should regard all un-stored values of sparse arrays as
525-
# structural zeros. Thus ProjectTo needs to store nzind, and get only those.
526-
# This implementation very naiive, can probably be made more efficient.
527-
528-
function ProjectTo(x::SparseVector{T}) where {T<:Number}
529-
return ProjectTo{SparseVector}(;
530-
element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x)
531-
)
532-
end
533-
function (project::ProjectTo{SparseVector})(dx::AbstractArray)
534-
dy = if axes(dx) == project.axes
535-
dx
536-
else
537-
if size(dx, 1) != length(project.axes[1])
538-
throw(_projection_mismatch(project.axes, size(dx)))
539-
end
540-
reshape(dx, project.axes)
541-
end
542-
nzval = map(i -> project.element(dy[i]), project.nzind)
543-
return SparseVector(length(dx), project.nzind, nzval)
544-
end
545-
function (project::ProjectTo{SparseVector})(dx::SparseVector)
546-
if size(dx) != map(length, project.axes)
547-
throw(_projection_mismatch(project.axes, size(dx)))
548-
end
549-
# When sparsity pattern is unchanged, all the time is in checking this,
550-
# perhaps some simple hash/checksum might be good enough?
551-
samepattern = project.nzind == dx.nzind
552-
# samepattern = length(project.nzind) == length(dx.nzind)
553-
if eltype(dx) <: project_type(project.element) && samepattern
554-
return dx
555-
elseif samepattern
556-
nzval = map(project.element, dx.nzval)
557-
SparseVector(length(dx), dx.nzind, nzval)
558-
else
559-
nzind = project.nzind
560-
# Or should we intersect? Can this exploit sorting?
561-
# nzind = intersect(project.nzind, dx.nzind)
562-
nzval = map(i -> project.element(dx[i]), nzind)
563-
return SparseVector(length(dx), nzind, nzval)
564-
end
565-
end
566-
567-
function ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number}
568-
return ProjectTo{SparseMatrixCSC}(;
569-
element=ProjectTo(zero(T)),
570-
axes=axes(x),
571-
rowval=rowvals(x),
572-
nzranges=nzrange.(Ref(x), axes(x, 2)),
573-
colptr=x.colptr,
574-
)
575-
end
576-
# You need not really store nzranges, you can get them from colptr -- TODO
577-
# nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1)
578-
function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray)
579-
dy = if axes(dx) == project.axes
580-
dx
581-
else
582-
if size(dx) != (length(project.axes[1]), length(project.axes[2]))
583-
throw(_projection_mismatch(project.axes, size(dx)))
584-
end
585-
reshape(dx, project.axes)
586-
end
587-
nzval = Vector{project_type(project.element)}(undef, length(project.rowval))
588-
k = 0
589-
for col in project.axes[2]
590-
for i in project.nzranges[col]
591-
row = project.rowval[i]
592-
val = dy[row, col]
593-
nzval[k += 1] = project.element(val)
594-
end
595-
end
596-
m, n = map(length, project.axes)
597-
return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval)
598-
end
599-
600-
function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC)
601-
if size(dx) != map(length, project.axes)
602-
throw(_projection_mismatch(project.axes, size(dx)))
603-
end
604-
samepattern = dx.colptr == project.colptr && dx.rowval == project.rowval
605-
# samepattern = length(dx.colptr) == length(project.colptr) && dx.colptr[end] == project.colptr[end]
606-
if eltype(dx) <: project_type(project.element) && samepattern
607-
return dx
608-
elseif samepattern
609-
nzval = map(project.element, dx.nzval)
610-
m, n = size(dx)
611-
return SparseMatrixCSC(m, n, dx.colptr, dx.rowval, nzval)
612-
else
613-
invoke(project, Tuple{AbstractArray}, dx)
614-
end
615-
end

0 commit comments

Comments
 (0)