Skip to content

Commit 619dc89

Browse files
Merge pull request #259 from avik-pal/ap/known_sparsity
Allow specifying colorvecs directly
2 parents 6bb4839 + a6731aa commit 619dc89

File tree

5 files changed

+140
-22
lines changed

5 files changed

+140
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
33
authors = ["Pankaj Mishra <pankajmishra1511@gmail.com>", "Chris Rackauckas <contact@chrisrackauckas.com>"]
4-
version = "2.5.2"
4+
version = "2.6.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/SparseDiffTools.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ export update_coefficients, update_coefficients!, value!
8888
# High Level Interface: sparse_jacobian
8989
export AutoSparseEnzyme
9090

91-
export NoSparsityDetection,
92-
SymbolicsSparsityDetection, JacPrototypeSparsityDetection, AutoSparsityDetection
91+
export NoSparsityDetection, SymbolicsSparsityDetection, JacPrototypeSparsityDetection,
92+
PrecomputedJacobianColorvec, AutoSparsityDetection
9393
export sparse_jacobian, sparse_jacobian_cache, sparse_jacobian!
94+
export init_jacobian
9495

9596
end # module

src/highlevel/coloring.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,16 @@ struct NoMatrixColoring end
1616
# Prespecified Jacobian Structure
1717
function (alg::JacPrototypeSparsityDetection)(ad::AbstractSparseADType, args...; kwargs...)
1818
J = alg.jac_prototype
19-
reverse_mode = ad isa AbstractSparseReverseMode
20-
colorvec = matrix_colors(J, alg.alg; partition_by_rows = reverse_mode)
19+
colorvec = matrix_colors(J, alg.alg;
20+
partition_by_rows = ad isa AbstractSparseReverseMode)
21+
(nz_rows, nz_cols) = ArrayInterface.findstructralnz(J)
22+
return MatrixColoringResult(colorvec, J, nz_rows, nz_cols)
23+
end
24+
25+
# Prespecified Colorvecs
26+
function (alg::PrecomputedJacobianColorvec)(ad::AbstractSparseADType, args...; kwargs...)
27+
colorvec = _get_colorvec(alg, ad)
28+
J = alg.jac_prototype
2129
(nz_rows, nz_cols) = ArrayInterface.findstructralnz(J)
2230
return MatrixColoringResult(colorvec, J, nz_rows, nz_cols)
2331
end

src/highlevel/common.jl

Lines changed: 116 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,110 @@ abstract type AbstractSparsityDetection <: AbstractMaybeSparsityDetection end
99

1010
struct NoSparsityDetection <: AbstractMaybeSparsityDetection end
1111

12+
"""
13+
SymbolicsSparsityDetection(; alg = GreedyD1Color())
14+
15+
Use Symbolics to compute the sparsity pattern of the Jacobian. This requires `Symbolics.jl`
16+
to be explicitly loaded.
17+
18+
## Keyword Arguments
19+
20+
- `alg`: The algorithm used for computing the matrix colors
21+
22+
See Also: [JacPrototypeSparsityDetection](@ref), [PrecomputedJacobianColorvec](@ref)
23+
"""
1224
Base.@kwdef struct SymbolicsSparsityDetection{A <: ArrayInterface.ColoringAlgorithm} <:
1325
AbstractSparsityDetection
1426
alg::A = GreedyD1Color()
1527
end
1628

29+
"""
30+
JacPrototypeSparsityDetection(; jac_prototype, alg = GreedyD1Color())
31+
32+
Use a pre-specified `jac_prototype` to compute the matrix colors of the Jacobian.
33+
34+
## Keyword Arguments
35+
36+
- `jac_prototype`: The prototype Jacobian used for computing the matrix colors
37+
- `alg`: The algorithm used for computing the matrix colors
38+
39+
See Also: [SymbolicsSparsityDetection](@ref), [PrecomputedJacobianColorvec](@ref)
40+
"""
1741
Base.@kwdef struct JacPrototypeSparsityDetection{
18-
J, A <: ArrayInterface.ColoringAlgorithm,
19-
} <: AbstractSparsityDetection
42+
J, A <: ArrayInterface.ColoringAlgorithm} <: AbstractSparsityDetection
2043
jac_prototype::J
2144
alg::A = GreedyD1Color()
2245
end
2346

47+
"""
48+
PrecomputedJacobianColorvec(jac_prototype, row_colorvec, col_colorvec)
49+
50+
Use a pre-specified `colorvec` which can be directly used for sparse differentiation. Based
51+
on whether a reverse mode or forward mode or finite differences is used, the corresponding
52+
`row_colorvec` or `col_colorvec` is used. Atmost one of them can be set to `nothing`.
53+
54+
## Arguments
55+
56+
- `jac_prototype`: The prototype Jacobian used for computing structural nonzeros
57+
- `row_colorvec`: The row colorvec of the Jacobian
58+
- `col_colorvec`: The column colorvec of the Jacobian
59+
60+
See Also: [SymbolicsSparsityDetection](@ref), [JacPrototypeSparsityDetection](@ref)
61+
"""
62+
struct PrecomputedJacobianColorvec{J, RC, CC} <: AbstractSparsityDetection
63+
jac_prototype::J
64+
row_colorvec::RC
65+
col_colorvec::CC
66+
end
67+
68+
"""
69+
PrecomputedJacobianColorvec(; jac_prototype, partition_by_rows::Bool = false,
70+
colorvec = missing, row_colorvec = missing, col_colorvec = missing)
71+
72+
Use a pre-specified `colorvec` which can be directly used for sparse differentiation. Based
73+
on whether a reverse mode or forward mode or finite differences is used, the corresponding
74+
`row_colorvec` or `col_colorvec` is used. Atmost one of them can be set to `nothing`.
75+
76+
## Keyword Arguments
77+
78+
- `jac_prototype`: The prototype Jacobian used for computing structural nonzeros
79+
- `partition_by_rows`: Whether to partition the Jacobian by rows or columns (row
80+
partitioning is used for reverse mode AD)
81+
- `colorvec`: The colorvec of the Jacobian. If `partition_by_rows` is `true` then this
82+
is the row colorvec, otherwise it is the column colorvec
83+
- `row_colorvec`: The row colorvec of the Jacobian
84+
- `col_colorvec`: The column colorvec of the Jacobian
85+
86+
See Also: [SymbolicsSparsityDetection](@ref), [JacPrototypeSparsityDetection](@ref)
87+
"""
88+
function PrecomputedJacobianColorvec(; jac_prototype, partition_by_rows::Bool = false,
89+
colorvec = missing, row_colorvec = missing, col_colorvec = missing)
90+
if colorvec === missing
91+
@assert row_colorvec !== missing||col_colorvec !== missing "Either `colorvec` or `row_colorvec` and `col_colorvec` must be specified!"
92+
row_colorvec = row_colorvec === missing ? nothing : row_colorvec
93+
col_colorvec = col_colorvec === missing ? nothing : col_colorvec
94+
return PrecomputedJacobianColorvec(jac_prototype, row_colorvec, col_colorvec)
95+
else
96+
@assert row_colorvec === missing&&col_colorvec === missing "Specifying `colorvec` is incompatible with specifying `row_colorvec` or `col_colorvec`!"
97+
row_colorvec = partition_by_rows ? colorvec : nothing
98+
col_colorvec = partition_by_rows ? nothing : colorvec
99+
return PrecomputedJacobianColorvec(jac_prototype, row_colorvec, col_colorvec)
100+
end
101+
end
102+
103+
function _get_colorvec(alg::PrecomputedJacobianColorvec, ad)
104+
cvec = alg.col_colorvec
105+
@assert cvec!==nothing "`col_colorvec` is nothing, but Forward Mode AD or Finite Differences is being used!"
106+
return cvec
107+
end
108+
109+
function _get_colorvec(alg::PrecomputedJacobianColorvec, ::AbstractReverseMode)
110+
cvec = alg.row_colorvec
111+
@assert cvec!==nothing "`row_colorvec` is nothing, but Reverse Mode AD is being used!"
112+
return cvec
113+
end
114+
115+
# No one should be using this currently
24116
Base.@kwdef struct AutoSparsityDetection{A <: ArrayInterface.ColoringAlgorithm} <:
25117
AbstractSparsityDetection
26118
alg::A = GreedyD1Color()
@@ -41,7 +133,8 @@ Inplace update the matrix `J` with the Jacobian of `f` at `x` using the AD backe
41133
function sparse_jacobian! end
42134

43135
"""
44-
sparse_jacobian_cache(ad::AbstractADType, sd::AbstractSparsityDetection, f, x; fx=nothing)
136+
sparse_jacobian_cache(ad::AbstractADType, sd::AbstractSparsityDetection, f, x;
137+
fx=nothing)
45138
sparse_jacobian_cache(ad::AbstractADType, sd::AbstractSparsityDetection, f!, fx, x)
46139
47140
Takes the underlying AD backend `ad`, sparsity detection algorithm `sd`, function `f`,
@@ -67,7 +160,7 @@ with the same cache to compute the jacobian.
67160
function sparse_jacobian(ad::AbstractADType, sd::AbstractMaybeSparsityDetection, args...;
68161
kwargs...)
69162
cache = sparse_jacobian_cache(ad, sd, args...; kwargs...)
70-
J = __init_𝒥(cache)
163+
J = init_jacobian(cache)
71164
return sparse_jacobian!(J, ad, cache, args...)
72165
end
73166

@@ -80,7 +173,7 @@ Jacobian at every function call
80173
"""
81174
function sparse_jacobian(ad::AbstractADType, cache::AbstractMaybeSparseJacobianCache,
82175
args...)
83-
J = __init_𝒥(cache)
176+
J = init_jacobian(cache)
84177
return sparse_jacobian!(J, ad, cache, args...)
85178
end
86179

@@ -106,7 +199,20 @@ function __gradient end
106199
function __gradient! end
107200
function __jacobian! end
108201

109-
function __init_𝒥 end
202+
"""
203+
init_jacobian(cache::AbstractMaybeSparseJacobianCache)
204+
205+
Initialize the Jacobian based on the cache. Uses sparse jacobians if possible.
206+
207+
!!! note
208+
This function doesn't alias the provided jacobian prototype. It always initializes a
209+
fresh jacobian that can be mutated without any side effects.
210+
"""
211+
function init_jacobian end
212+
213+
# Never thought this was a useful function externally, but I ended up using it in quite a
214+
# few places. Keeping this till I remove uses of those.
215+
const __init_𝒥 = init_jacobian
110216

111217
# Misc Functions
112218
__chunksize(::AutoSparseForwardDiff{C}) where {C} = C
@@ -123,12 +229,12 @@ end
123229
return :(nothing)
124230
end
125231

126-
function __init_𝒥(c::AbstractMaybeSparseJacobianCache)
232+
function init_jacobian(c::AbstractMaybeSparseJacobianCache)
127233
T = promote_type(eltype(c.fx), eltype(c.x))
128-
return __init_𝒥(__getfield(c, Val(:jac_prototype)), T, c.fx, c.x)
234+
return init_jacobian(__getfield(c, Val(:jac_prototype)), T, c.fx, c.x)
129235
end
130-
__init_𝒥(::Nothing, ::Type{T}, fx, x) where {T} = similar(fx, T, length(fx), length(x))
131-
__init_𝒥(J, ::Type{T}, _, _) where {T} = similar(J, T, size(J, 1), size(J, 2))
236+
init_jacobian(::Nothing, ::Type{T}, fx, x) where {T} = similar(fx, T, length(fx), length(x))
237+
init_jacobian(J, ::Type{T}, _, _) where {T} = similar(J, T, size(J, 1), size(J, 2))
132238

133239
__maybe_copy_x(_, x) = x
134240
__maybe_copy_x(_, ::Nothing) = nothing

test/test_sparse_jacobian.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@ J_true = ForwardDiff.jacobian(fdiff, x);
2626

2727
# SparseDiffTools High-Level API
2828
J_sparsity = Symbolics.jacobian_sparsity(fdiff, similar(x), x);
29+
row_colorvec = SparseDiffTools.matrix_colors(J_sparsity; partition_by_rows = true)
30+
col_colorvec = SparseDiffTools.matrix_colors(J_sparsity; partition_by_rows = false)
2931

30-
SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(jac_prototype = J_sparsity),
31-
SymbolicsSparsityDetection(), NoSparsityDetection()]
32+
SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_sparsity),
33+
SymbolicsSparsityDetection(), NoSparsityDetection(),
34+
PrecomputedJacobianColorvec(; jac_prototype = J_sparsity, row_colorvec, col_colorvec)]
3235

3336
@testset "High-Level API" begin
3437
@testset "Sparsity Detection: $(nameof(typeof(sd)))" for sd in SPARSITY_DETECTION_ALGS
@@ -40,7 +43,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(jac_prototype = J_spars
4043
AutoSparseFiniteDiff(), AutoFiniteDiff(), AutoEnzyme(), AutoSparseEnzyme())
4144
@testset "Cache & Reuse" begin
4245
cache = sparse_jacobian_cache(difftype, sd, fdiff, x)
43-
J = SparseDiffTools.__init_𝒥(cache)
46+
J = init_jacobian(cache)
4447

4548
sparse_jacobian!(J, difftype, cache, fdiff, x)
4649

@@ -74,7 +77,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(jac_prototype = J_spars
7477
@info "$(nameof(typeof(difftype)))() `sparse_jacobian` (complete) time: $(t₁)s"
7578

7679
cache = sparse_jacobian_cache(difftype, sd, fdiff, x)
77-
J = SparseDiffTools.__init_𝒥(cache)
80+
J = init_jacobian(cache)
7881

7982
sparse_jacobian!(J, difftype, sd, fdiff, x)
8083

@@ -95,7 +98,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(jac_prototype = J_spars
9598
cache = sparse_jacobian_cache(difftype, sd, fdiff, y, x)
9699

97100
@testset "Cache & Reuse" begin
98-
J = SparseDiffTools.__init_𝒥(cache)
101+
J = init_jacobian(cache)
99102
sparse_jacobian!(J, difftype, cache, fdiff, y, x)
100103

101104
@test J J_true
@@ -126,7 +129,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(jac_prototype = J_spars
126129
t₁ = @elapsed sparse_jacobian(difftype, sd, fdiff, y, x)
127130
@info "$(nameof(typeof(difftype)))() `sparse_jacobian` (complete) time: $(t₁)s"
128131

129-
J = SparseDiffTools.__init_𝒥(cache)
132+
J = init_jacobian(cache)
130133

131134
sparse_jacobian!(J, difftype, sd, fdiff, y, x)
132135

@@ -142,7 +145,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(jac_prototype = J_spars
142145
AutoZygote())
143146
y = similar(x)
144147
cache = sparse_jacobian_cache(difftype, sd, fdiff, y, x)
145-
J = SparseDiffTools.__init_𝒥(cache)
148+
J = init_jacobian(cache)
146149

147150
@testset "Cache & Reuse" begin
148151
@test_throws Exception sparse_jacobian!(J, difftype, cache, fdiff, y, x)

0 commit comments

Comments
 (0)