Skip to content

Commit 7c1264f

Browse files
committed
Add PolyesterForwardDiff Support
1 parent 4ad919a commit 7c1264f

File tree

4 files changed

+111
-14
lines changed

4 files changed

+111
-14
lines changed

Project.toml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
33
authors = ["Pankaj Mishra <pankajmishra1511@gmail.com>", "Chris Rackauckas <contact@chrisrackauckas.com>"]
4-
version = "2.15.1"
4+
version = "2.16.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -27,17 +27,19 @@ VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
2727

2828
[weakdeps]
2929
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
30+
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
3031
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
3132
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3233

3334
[extensions]
3435
SparseDiffToolsEnzymeExt = "Enzyme"
36+
SparseDiffToolsPolyesterForwardDiffExt = "PolyesterForwardDiff"
3537
SparseDiffToolsSymbolicsExt = "Symbolics"
3638
SparseDiffToolsZygoteExt = "Zygote"
3739

3840
[compat]
39-
ADTypes = "0.2.1"
40-
Adapt = "3.0, 4"
41+
ADTypes = "0.2.6"
42+
Adapt = "3, 4"
4143
ArrayInterface = "7.4.2"
4244
Compat = "4"
4345
DataStructures = "0.18"
@@ -47,7 +49,8 @@ ForwardDiff = "0.10"
4749
Graphs = "1"
4850
LinearAlgebra = "<0.0.1, 1"
4951
PackageExtensionCompat = "1"
50-
Random = "<0.0.1, 1"
52+
PolyesterForwardDiff = "0.1.1"
53+
Random = "1.6"
5154
Reexport = "1"
5255
SciMLOperators = "0.3.7"
5356
Setfield = "1"
@@ -67,6 +70,7 @@ BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
6770
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
6871
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
6972
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
73+
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
7074
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
7175
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
7276
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -75,4 +79,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7579
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7680

7781
[targets]
78-
test = ["Test", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"]
82+
test = ["Test", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays", "PolyesterForwardDiff"]
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
module SparseDiffToolsPolyesterForwardDiffExt
2+
3+
using ADTypes, SparseDiffTools, PolyesterForwardDiff
4+
import ForwardDiff
5+
import SparseDiffTools: AbstractMaybeSparseJacobianCache, AbstractMaybeSparsityDetection,
6+
ForwardColorJacCache, NoMatrixColoring, sparse_jacobian_cache, sparse_jacobian!,
7+
sparse_jacobian_static_array, __standard_tag, __chunksize
8+
9+
struct PolyesterForwardDiffJacobianCache{CO, CA, J, FX, X} <:
10+
AbstractMaybeSparseJacobianCache
11+
coloring::CO
12+
cache::CA
13+
jac_prototype::J
14+
fx::FX
15+
x::X
16+
end
17+
18+
function sparse_jacobian_cache(ad::Union{AutoSparsePolyesterForwardDiff,
19+
AutoPolyesterForwardDiff}, sd::AbstractMaybeSparsityDetection, f::F, x;
20+
fx = nothing) where {F}
21+
coloring_result = sd(ad, f, x)
22+
fx = fx === nothing ? similar(f(x)) : fx
23+
if coloring_result isa NoMatrixColoring
24+
cache = __chunksize(ad, x)
25+
jac_prototype = nothing
26+
else
27+
@warn """Currently PolyesterForwardDiff does not support sparsity detection
28+
natively. Falling back to using ForwardDiff.jl""" maxlog=1
29+
tag = __standard_tag(nothing, x)
30+
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
31+
cache = ForwardColorJacCache(f, x, __chunksize(ad); coloring_result.colorvec,
32+
dx = fx, sparsity = coloring_result.jacobian_sparsity, tag = typeof(tag))
33+
jac_prototype = coloring_result.jacobian_sparsity
34+
end
35+
return PolyesterForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
36+
end
37+
38+
function sparse_jacobian_cache(ad::Union{AutoSparsePolyesterForwardDiff,
39+
AutoPolyesterForwardDiff}, sd::AbstractMaybeSparsityDetection, f!::F, fx,
40+
x) where {F}
41+
coloring_result = sd(ad, f!, fx, x)
42+
if coloring_result isa NoMatrixColoring
43+
cache = __chunksize(ad, x)
44+
jac_prototype = nothing
45+
else
46+
@warn """Currently PolyesterForwardDiff does not support sparsity detection
47+
natively. Falling back to using ForwardDiff.jl""" maxlog=1
48+
tag = __standard_tag(nothing, x)
49+
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
50+
cache = ForwardColorJacCache(f!, x, __chunksize(ad); coloring_result.colorvec,
51+
dx = fx, sparsity = coloring_result.jacobian_sparsity, tag = typeof(tag))
52+
jac_prototype = coloring_result.jacobian_sparsity
53+
end
54+
return PolyesterForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
55+
end
56+
57+
function sparse_jacobian!(J::AbstractMatrix, _, cache::PolyesterForwardDiffJacobianCache,
58+
f::F, x) where {F}
59+
if cache.cache isa ForwardColorJacCache
60+
forwarddiff_color_jacobian(J, f, x, cache.cache) # Use Sparse ForwardDiff
61+
else
62+
PolyesterForwardDiff.threaded_jacobian!(f, J, x, cache.cache) # Don't try to exploit sparsity
63+
end
64+
return J
65+
end
66+
67+
function sparse_jacobian!(J::AbstractMatrix, _, cache::PolyesterForwardDiffJacobianCache,
68+
f!::F, fx, x) where {F}
69+
if cache.cache isa ForwardColorJacCache
70+
forwarddiff_color_jacobian!(J, f!, x, cache.cache) # Use Sparse ForwardDiff
71+
else
72+
PolyesterForwardDiff.threaded_jacobian!(f!, fx, J, x, cache.cache) # Don't try to exploit sparsity
73+
end
74+
return J
75+
end
76+
77+
end

src/highlevel/common.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,8 @@ function init_jacobian end
269269
const __init_𝒥 = init_jacobian
270270

271271
# Misc Functions
272-
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C}}, x) where {C}
272+
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C},
273+
AutoSparsePolyesterForwardDiff{C}, AutoPolyesterForwardDiff{C}}, x) where {C}
273274
C isa ForwardDiff.Chunk && return C
274275
return __chunksize(Val(C), x)
275276
end
@@ -285,7 +286,8 @@ end
285286
__chunksize(x) = ForwardDiff.Chunk(x)
286287
__chunksize(x::StaticArray) = ForwardDiff.Chunk{ForwardDiff.pickchunksize(prod(Size(x)))}()
287288

288-
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C}}) where {C}
289+
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C},
290+
AutoSparsePolyesterForwardDiff{C}, AutoPolyesterForwardDiff{C}}) where {C}
289291
C === nothing && return nothing
290292
C isa Integer && !(C isa Bool) && return C 0 ? nothing : Val(C)
291293
return nothing
@@ -347,4 +349,5 @@ end
347349
@inline __backend(::Union{AutoEnzyme, AutoSparseEnzyme}) = :Enzyme
348350
@inline __backend(::Union{AutoZygote, AutoSparseZygote}) = :Zygote
349351
@inline __backend(::Union{AutoForwardDiff, AutoSparseForwardDiff}) = :ForwardDiff
352+
@inline __backend(::Union{AutoPolyesterForwardDiff, AutoSparsePolyesterForwardDiff}) = :PolyesterForwardDiff
350353
@inline __backend(::Union{AutoFiniteDiff, AutoSparseFiniteDiff}) = :FiniteDiff

test/test_sparse_jacobian.jl

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Sparse Jacobian tests
2-
using SparseDiffTools,
3-
Symbolics, ForwardDiff, LinearAlgebra, SparseArrays, Zygote, Enzyme, Test, StaticArrays
2+
using SparseDiffTools, PolyesterForwardDiff, Symbolics, ForwardDiff, LinearAlgebra,
3+
SparseArrays, Zygote, Enzyme, Test, StaticArrays
44

55
@views function fdiff(y, x) # in-place
66
L = length(x)
@@ -42,7 +42,12 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
4242
AutoZygote(), AutoSparseForwardDiff(), AutoForwardDiff(),
4343
AutoSparseForwardDiff(; chunksize = 0), AutoForwardDiff(; chunksize = 0),
4444
AutoSparseForwardDiff(; chunksize = 4), AutoForwardDiff(; chunksize = 4),
45-
AutoSparseFiniteDiff(), AutoFiniteDiff(), AutoEnzyme(), AutoSparseEnzyme())
45+
AutoSparsePolyesterForwardDiff(), AutoPolyesterForwardDiff(),
46+
AutoSparsePolyesterForwardDiff(; chunksize = 0),
47+
AutoPolyesterForwardDiff(; chunksize = 0),
48+
AutoSparsePolyesterForwardDiff(; chunksize = 4),
49+
AutoPolyesterForwardDiff(; chunksize = 4), AutoSparseFiniteDiff(),
50+
AutoFiniteDiff(), AutoEnzyme(), AutoSparseEnzyme())
4651
@testset "Cache & Reuse" begin
4752
cache = sparse_jacobian_cache(difftype, sd, fdiff, x)
4853
J = init_jacobian(cache)
@@ -59,7 +64,9 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
5964

6065
@test J J_true
6166

62-
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff)
67+
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff ||
68+
difftype isa AutoSparsePolyesterForwardDiff ||
69+
difftype isa AutoPolyesterForwardDiff)
6370
@inferred sparse_jacobian(difftype, cache, fdiff, x)
6471
end
6572

@@ -71,7 +78,9 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
7178
J = sparse_jacobian(difftype, sd, fdiff, x)
7279

7380
@test J J_true
74-
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff)
81+
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff ||
82+
difftype isa AutoSparsePolyesterForwardDiff ||
83+
difftype isa AutoPolyesterForwardDiff)
7584
@inferred sparse_jacobian(difftype, sd, fdiff, x)
7685
end
7786

@@ -114,7 +123,9 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
114123
J = sparse_jacobian(difftype, cache, fdiff, y, x)
115124

116125
@test J J_true
117-
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff)
126+
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff ||
127+
difftype isa AutoSparsePolyesterForwardDiff ||
128+
difftype isa AutoPolyesterForwardDiff)
118129
@inferred sparse_jacobian(difftype, cache, fdiff, y, x)
119130
end
120131

@@ -126,7 +137,9 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
126137
J = sparse_jacobian(difftype, sd, fdiff, y, x)
127138

128139
@test J J_true
129-
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff)
140+
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff ||
141+
difftype isa AutoSparsePolyesterForwardDiff ||
142+
difftype isa AutoPolyesterForwardDiff)
130143
@inferred sparse_jacobian(difftype, sd, fdiff, y, x)
131144
end
132145

0 commit comments

Comments
 (0)