Skip to content

Commit 3b42466

Browse files
Merge pull request #281 from avik-pal/ap/polyester-forward-diff
Add PolyesterForwardDiff Support
2 parents bbd1f6e + 9ccbe30 commit 3b42466

File tree

7 files changed

+130
-18
lines changed

7 files changed

+130
-18
lines changed

Project.toml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
33
authors = ["Pankaj Mishra <pankajmishra1511@gmail.com>", "Chris Rackauckas <contact@chrisrackauckas.com>"]
4-
version = "2.15.1"
4+
version = "2.16.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -27,17 +27,19 @@ VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
2727

2828
[weakdeps]
2929
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
30+
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
3031
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
3132
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3233

3334
[extensions]
3435
SparseDiffToolsEnzymeExt = "Enzyme"
36+
SparseDiffToolsPolyesterForwardDiffExt = "PolyesterForwardDiff"
3537
SparseDiffToolsSymbolicsExt = "Symbolics"
3638
SparseDiffToolsZygoteExt = "Zygote"
3739

3840
[compat]
39-
ADTypes = "0.2.1"
40-
Adapt = "3.0, 4"
41+
ADTypes = "0.2.6"
42+
Adapt = "3, 4"
4143
ArrayInterface = "7.4.2"
4244
Compat = "4"
4345
DataStructures = "0.18"
@@ -47,7 +49,8 @@ ForwardDiff = "0.10"
4749
Graphs = "1"
4850
LinearAlgebra = "<0.0.1, 1"
4951
PackageExtensionCompat = "1"
50-
Random = "<0.0.1, 1"
52+
PolyesterForwardDiff = "0.1.1"
53+
Random = "1.6"
5154
Reexport = "1"
5255
SciMLOperators = "0.3.7"
5356
Setfield = "1"
@@ -67,6 +70,7 @@ BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
6770
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
6871
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
6972
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
73+
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
7074
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
7175
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
7276
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
module SparseDiffToolsPolyesterForwardDiffExt
2+
3+
using ADTypes, SparseDiffTools, PolyesterForwardDiff
4+
import ForwardDiff
5+
import SparseDiffTools: AbstractMaybeSparseJacobianCache, AbstractMaybeSparsityDetection,
6+
ForwardColorJacCache, NoMatrixColoring, sparse_jacobian_cache, sparse_jacobian!,
7+
sparse_jacobian_static_array, __standard_tag, __chunksize
8+
9+
struct PolyesterForwardDiffJacobianCache{CO, CA, J, FX, X} <:
10+
AbstractMaybeSparseJacobianCache
11+
coloring::CO
12+
cache::CA
13+
jac_prototype::J
14+
fx::FX
15+
x::X
16+
end
17+
18+
function sparse_jacobian_cache(ad::Union{AutoSparsePolyesterForwardDiff,
19+
AutoPolyesterForwardDiff}, sd::AbstractMaybeSparsityDetection, f::F, x;
20+
fx = nothing) where {F}
21+
coloring_result = sd(ad, f, x)
22+
fx = fx === nothing ? similar(f(x)) : fx
23+
if coloring_result isa NoMatrixColoring
24+
cache = __chunksize(ad, x)
25+
jac_prototype = nothing
26+
else
27+
@warn """Currently PolyesterForwardDiff does not support sparsity detection
28+
natively. Falling back to using ForwardDiff.jl""" maxlog=1
29+
tag = __standard_tag(nothing, x)
30+
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
31+
cache = ForwardColorJacCache(f, x, __chunksize(ad); coloring_result.colorvec,
32+
dx = fx, sparsity = coloring_result.jacobian_sparsity, tag = typeof(tag))
33+
jac_prototype = coloring_result.jacobian_sparsity
34+
end
35+
return PolyesterForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
36+
end
37+
38+
function sparse_jacobian_cache(ad::Union{AutoSparsePolyesterForwardDiff,
39+
AutoPolyesterForwardDiff}, sd::AbstractMaybeSparsityDetection, f!::F, fx,
40+
x) where {F}
41+
coloring_result = sd(ad, f!, fx, x)
42+
if coloring_result isa NoMatrixColoring
43+
cache = __chunksize(ad, x)
44+
jac_prototype = nothing
45+
else
46+
@warn """Currently PolyesterForwardDiff does not support sparsity detection
47+
natively. Falling back to using ForwardDiff.jl""" maxlog=1
48+
tag = __standard_tag(nothing, x)
49+
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
50+
cache = ForwardColorJacCache(f!, x, __chunksize(ad); coloring_result.colorvec,
51+
dx = fx, sparsity = coloring_result.jacobian_sparsity, tag = typeof(tag))
52+
jac_prototype = coloring_result.jacobian_sparsity
53+
end
54+
return PolyesterForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
55+
end
56+
57+
function sparse_jacobian!(J::AbstractMatrix, _, cache::PolyesterForwardDiffJacobianCache,
58+
f::F, x) where {F}
59+
if cache.cache isa ForwardColorJacCache
60+
forwarddiff_color_jacobian(J, f, x, cache.cache) # Use Sparse ForwardDiff
61+
else
62+
PolyesterForwardDiff.threaded_jacobian!(f, J, x, cache.cache) # Don't try to exploit sparsity
63+
end
64+
return J
65+
end
66+
67+
function sparse_jacobian!(J::AbstractMatrix, _, cache::PolyesterForwardDiffJacobianCache,
68+
f!::F, fx, x) where {F}
69+
if cache.cache isa ForwardColorJacCache
70+
forwarddiff_color_jacobian!(J, f!, x, cache.cache) # Use Sparse ForwardDiff
71+
else
72+
PolyesterForwardDiff.threaded_jacobian!(f!, fx, J, x, cache.cache) # Don't try to exploit sparsity
73+
end
74+
return J
75+
end
76+
77+
end

src/highlevel/common.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,8 @@ function init_jacobian end
269269
const __init_𝒥 = init_jacobian
270270

271271
# Misc Functions
272-
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C}}, x) where {C}
272+
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C},
273+
AutoSparsePolyesterForwardDiff{C}, AutoPolyesterForwardDiff{C}}, x) where {C}
273274
C isa ForwardDiff.Chunk && return C
274275
return __chunksize(Val(C), x)
275276
end
@@ -285,7 +286,8 @@ end
285286
__chunksize(x) = ForwardDiff.Chunk(x)
286287
__chunksize(x::StaticArray) = ForwardDiff.Chunk{ForwardDiff.pickchunksize(prod(Size(x)))}()
287288

288-
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C}}) where {C}
289+
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C},
290+
AutoSparsePolyesterForwardDiff{C}, AutoPolyesterForwardDiff{C}}) where {C}
289291
C === nothing && return nothing
290292
C isa Integer && !(C isa Bool) && return C 0 ? nothing : Val(C)
291293
return nothing
@@ -347,4 +349,5 @@ end
347349
@inline __backend(::Union{AutoEnzyme, AutoSparseEnzyme}) = :Enzyme
348350
@inline __backend(::Union{AutoZygote, AutoSparseZygote}) = :Zygote
349351
@inline __backend(::Union{AutoForwardDiff, AutoSparseForwardDiff}) = :ForwardDiff
352+
@inline __backend(::Union{AutoPolyesterForwardDiff, AutoSparsePolyesterForwardDiff}) = :PolyesterForwardDiff
350353
@inline __backend(::Union{AutoFiniteDiff, AutoSparseFiniteDiff}) = :FiniteDiff

test/1.9specific/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[deps]
2+
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
3+
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"

test/allocs/Project.toml

Lines changed: 0 additions & 2 deletions
This file was deleted.

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ if GROUP == "Core" || GROUP == "All"
4242
end
4343

4444
if GROUP == "InterfaceI" || GROUP == "All"
45-
VERSION v"1.9" && activate_env("allocs")
45+
VERSION v"1.9" && activate_env("1.9specific")
4646
@time @safetestset "Jac Vecs and Hes Vecs" begin
4747
include("test_jaches_products.jl")
4848
end

test/test_sparse_jacobian.jl

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,22 @@
22
using SparseDiffTools,
33
Symbolics, ForwardDiff, LinearAlgebra, SparseArrays, Zygote, Enzyme, Test, StaticArrays
44

5+
@static if VERSION v"1.9"
6+
using PolyesterForwardDiff
7+
end
8+
9+
function __chunksize(::Union{AutoSparseForwardDiff{C}, AutoForwardDiff{C},
10+
AutoSparsePolyesterForwardDiff{C}, AutoPolyesterForwardDiff{C}}) where {C}
11+
return C
12+
end
13+
14+
function __isinferrable(difftype)
15+
return !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff ||
16+
difftype isa AutoSparsePolyesterForwardDiff ||
17+
difftype isa AutoPolyesterForwardDiff) ||
18+
(__chunksize(difftype) isa Int && __chunksize(difftype) > 0)
19+
end
20+
521
@views function fdiff(y, x) # in-place
622
L = length(x)
723
y[2:(L - 1)] .= x[1:(L - 2)] .- 2 .* x[2:(L - 1)] .+ x[3:L]
@@ -38,11 +54,22 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
3854
@info "Sparsity Detection: $(nameof(typeof(sd)))"
3955
@info "Out of Place Function"
4056

41-
@testset "sparse_jacobian $(nameof(typeof(difftype))): Out of Place" for difftype in (AutoSparseZygote(),
42-
AutoZygote(), AutoSparseForwardDiff(), AutoForwardDiff(),
43-
AutoSparseForwardDiff(; chunksize = 0), AutoForwardDiff(; chunksize = 0),
44-
AutoSparseForwardDiff(; chunksize = 4), AutoForwardDiff(; chunksize = 4),
45-
AutoSparseFiniteDiff(), AutoFiniteDiff(), AutoEnzyme(), AutoSparseEnzyme())
57+
DIFFTYPES = [AutoSparseZygote(), AutoZygote(), AutoSparseForwardDiff(),
58+
AutoForwardDiff(), AutoSparseForwardDiff(; chunksize = 0),
59+
AutoForwardDiff(; chunksize = 0), AutoSparseForwardDiff(; chunksize = 4),
60+
AutoForwardDiff(; chunksize = 4), AutoSparseFiniteDiff(), AutoFiniteDiff(),
61+
AutoEnzyme(), AutoSparseEnzyme()]
62+
63+
if VERSION v"1.9"
64+
append!(DIFFTYPES,
65+
[AutoSparsePolyesterForwardDiff(), AutoPolyesterForwardDiff(),
66+
AutoSparsePolyesterForwardDiff(; chunksize = 0),
67+
AutoPolyesterForwardDiff(; chunksize = 0),
68+
AutoSparsePolyesterForwardDiff(; chunksize = 4),
69+
AutoPolyesterForwardDiff(; chunksize = 4)])
70+
end
71+
72+
@testset "sparse_jacobian $(nameof(typeof(difftype))): Out of Place" for difftype in DIFFTYPES
4673
@testset "Cache & Reuse" begin
4774
cache = sparse_jacobian_cache(difftype, sd, fdiff, x)
4875
J = init_jacobian(cache)
@@ -59,7 +86,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
5986

6087
@test J J_true
6188

62-
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff)
89+
if __isinferrable(difftype)
6390
@inferred sparse_jacobian(difftype, cache, fdiff, x)
6491
end
6592

@@ -71,7 +98,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
7198
J = sparse_jacobian(difftype, sd, fdiff, x)
7299

73100
@test J J_true
74-
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff)
101+
if __isinferrable(difftype)
75102
@inferred sparse_jacobian(difftype, sd, fdiff, x)
76103
end
77104

@@ -114,7 +141,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
114141
J = sparse_jacobian(difftype, cache, fdiff, y, x)
115142

116143
@test J J_true
117-
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff)
144+
if __isinferrable(difftype)
118145
@inferred sparse_jacobian(difftype, cache, fdiff, y, x)
119146
end
120147

@@ -126,7 +153,7 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
126153
J = sparse_jacobian(difftype, sd, fdiff, y, x)
127154

128155
@test J J_true
129-
if !(difftype isa AutoSparseForwardDiff || difftype isa AutoForwardDiff)
156+
if __isinferrable(difftype)
130157
@inferred sparse_jacobian(difftype, sd, fdiff, y, x)
131158
end
132159

0 commit comments

Comments
 (0)