Skip to content

Commit 88aaa61

Browse files
committed
Fix more dispatches
1 parent 20cc8e7 commit 88aaa61

File tree

4 files changed

+48
-26
lines changed

4 files changed

+48
-26
lines changed

src/SparseDiffTools.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ include("highlevel/common.jl")
5454
include("highlevel/coloring.jl")
5555
include("highlevel/forward_mode.jl")
5656
include("highlevel/reverse_mode.jl")
57+
include("highlevel/forward_or_reverse_mode.jl")
5758
include("highlevel/finite_diff.jl")
5859

5960
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper

src/highlevel/common.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,12 @@ If `fx` is not specified, it will be computed by calling `f(x)`.
171171
A cache for computing the Jacobian of type `AbstractMaybeSparseJacobianCache`.
172172
"""
173173
function sparse_jacobian_cache(
174-
ad::AbstractADType, sd::AbstractSparsityDetection, f, x; fx = nothing)
174+
ad::AbstractADType, sd::AbstractMaybeSparsityDetection, f, x; fx = nothing)
175175
return sparse_jacobian_cache_aux(mode(ad), ad, sd, f, x; fx)
176176
end
177177

178178
function sparse_jacobian_cache(
179-
ad::AbstractADType, sd::AbstractSparsityDetection, f!, x, fx)
179+
ad::AbstractADType, sd::AbstractMaybeSparsityDetection, f!, x, fx)
180180
return sparse_jacobian_cache_aux(mode(ad), ad, sd, f!, x, fx)
181181
end
182182

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
function sparse_jacobian_cache_aux(::ForwardOrReverseMode, ad::AbstractADType,
2+
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
3+
if ad isa AutoEnzyme
4+
return sparse_jacobian_cache_aux(ReverseMode(), ad, sd, f, x; fx)
5+
elseif ad isa AutoDiffractor
6+
return sparse_jacobian_cache_aux(ForwardMode(), ad, sd, f, x; fx)
7+
else
8+
error("Unknown mixed mode AD")
9+
end
10+
end
11+
12+
function sparse_jacobian_cache_aux(::ForwardOrReverseMode, ad::AbstractADType,
13+
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
14+
if ad isa AutoEnzyme
15+
return sparse_jacobian_cache_aux(ReverseMode(), ad, sd, f!, fx, x)
16+
elseif ad isa AutoDiffractor
17+
return sparse_jacobian_cache_aux(ForwardMode(), ad, sd, f!, fx, x)
18+
else
19+
error("Unknown mixed mode AD")
20+
end
21+
end

test/test_sparse_jacobian.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
## Sparse Jacobian tests
2-
using SparseDiffTools,
3-
Symbolics, ForwardDiff, LinearAlgebra, SparseArrays, Zygote, Enzyme, Test,
2+
using ADTypes, SparseDiffTools,
3+
Symbolics, ForwardDiff, PolyesterForwardDiff, LinearAlgebra, SparseArrays, Zygote,
4+
Enzyme, Test,
45
StaticArrays
6+
using ADTypes: dense_ad
57

6-
@static if VERSION v"1.9"
7-
using PolyesterForwardDiff
8-
end
9-
10-
function __chunksize(::Union{AutoSparse{<:AutoForwardDiff}{C}, AutoForwardDiff{C},
11-
AutoSparse{<:AutoPolyesterForwardDiff}{C}, AutoPolyesterForwardDiff{C}}) where {C}
8+
function __chunksize(::Union{
9+
AutoSparse{<:AutoForwardDiff{C}}, AutoForwardDiff{C},
10+
AutoSparse{<:AutoPolyesterForwardDiff{C}}, AutoPolyesterForwardDiff{C}
11+
}) where {C}
1212
return C
1313
end
1414

1515
function __isinferrable(difftype)
16-
return !(difftype isa AutoSparse{<:AutoForwardDiff} || difftype isa AutoForwardDiff ||
16+
return !(difftype isa AutoSparse{<:AutoForwardDiff} ||
17+
difftype isa AutoForwardDiff ||
1718
difftype isa AutoSparse{<:AutoPolyesterForwardDiff} ||
1819
difftype isa AutoPolyesterForwardDiff) ||
1920
(__chunksize(difftype) isa Int && __chunksize(difftype) > 0)
@@ -51,24 +52,23 @@ SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_spa
5152
PrecomputedJacobianColorvec(; jac_prototype = J_sparsity, row_colorvec, col_colorvec)]
5253

5354
@testset "High-Level API" begin
54-
@testset "Sparsity Detection: $(nameof(typeof(sd)))" for sd in SPARSITY_DETECTION_ALGS
55+
@testset "Sparsity Detection: $(nameof(typeof(sd))) - $(isa(ad, AutoSparse) ? $(nameof(typeof(dense_ad(ad)))) : "")" for sd in SPARSITY_DETECTION_ALGS
5556
@info "Sparsity Detection: $(nameof(typeof(sd)))"
5657
@info "Out of Place Function"
5758

58-
DIFFTYPES = [AutoSparse(AutoZygote()), AutoZygote(), AutoSparse(AutoForwardDiff()),
59-
AutoForwardDiff(), AutoSparse(AutoForwardDiff(; chunksize = 0)),
60-
AutoForwardDiff(; chunksize = 0), AutoSparse(AutoForwardDiff(; chunksize = 4)),
61-
AutoForwardDiff(; chunksize = 4), AutoSparse(AutoFiniteDiff()), AutoFiniteDiff(),
62-
AutoEnzyme(), AutoSparse(AutoEnzyme())]
63-
64-
if VERSION v"1.9"
65-
append!(DIFFTYPES,
66-
[AutoSparse(AutoPolyesterForwardDiff()), AutoPolyesterForwardDiff(),
67-
AutoSparse(AutoPolyesterForwardDiff(; chunksize = 0)),
68-
AutoPolyesterForwardDiff(; chunksize = 0),
69-
AutoSparse(AutoPolyesterForwardDiff(; chunksize = 4)),
70-
AutoPolyesterForwardDiff(; chunksize = 4)])
71-
end
59+
DIFFTYPES = [
60+
AutoSparse(AutoZygote()), AutoZygote(),
61+
AutoSparse(AutoForwardDiff()), AutoForwardDiff(),
62+
AutoSparse(AutoForwardDiff(; chunksize = 0)), AutoForwardDiff(; chunksize = 0),
63+
AutoSparse(AutoForwardDiff(; chunksize = 4)), AutoForwardDiff(; chunksize = 4),
64+
AutoSparse(AutoFiniteDiff()), AutoFiniteDiff(),
65+
AutoEnzyme(), AutoSparse(AutoEnzyme()),
66+
AutoSparse(AutoPolyesterForwardDiff()), AutoPolyesterForwardDiff(),
67+
AutoSparse(AutoPolyesterForwardDiff(; chunksize = 0)),
68+
AutoPolyesterForwardDiff(; chunksize = 0),
69+
AutoSparse(AutoPolyesterForwardDiff(; chunksize = 4)),
70+
AutoPolyesterForwardDiff(; chunksize = 4)
71+
]
7272

7373
@testset "sparse_jacobian $(nameof(typeof(difftype))): Out of Place" for difftype in DIFFTYPES
7474
@testset "Cache & Reuse" begin

0 commit comments

Comments
 (0)