Skip to content

Commit d6fae4e

Browse files
committed
Enable tag checking
1 parent a7666fa commit d6fae4e

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

src/highlevel/forward_mode.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,26 @@ __getfield(c::ForwardDiffJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype
1010

1111
struct SparseDiffToolsTag end
1212

13+
function ForwardDiff.checktag(::Type{<:ForwardDiff.Tag{<:SparseDiffToolsTag, <:T}}, f::F,
14+
x::AbstractArray{T}) where {T, F}
15+
return true
16+
end
17+
1318
__standard_tag(::Nothing, x) = ForwardDiff.Tag(SparseDiffToolsTag(), eltype(x))
14-
__standard_tag(tag, _) = tag
19+
__standard_tag(tag, x) = ForwardDiff.Tag(tag, eltype(x))
1520

1621
function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff},
1722
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
1823
coloring_result = sd(ad, f, x)
1924
fx = fx === nothing ? similar(f(x)) : fx
25+
tag = __standard_tag(ad.tag, x)
2026
if coloring_result isa NoMatrixColoring
21-
cache = ForwardDiff.JacobianConfig(f, x, __chunksize(ad, x),
22-
__standard_tag(ad.tag, x))
27+
cache = ForwardDiff.JacobianConfig(f, x, __chunksize(ad, x), tag)
2328
jac_prototype = nothing
2429
else
30+
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
2531
cache = ForwardColorJacCache(f, x, __chunksize(ad); coloring_result.colorvec,
26-
dx = fx, sparsity = coloring_result.jacobian_sparsity, ad.tag)
32+
dx = fx, sparsity = coloring_result.jacobian_sparsity, tag = typeof(tag))
2733
jac_prototype = coloring_result.jacobian_sparsity
2834
end
2935
return ForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
@@ -32,13 +38,14 @@ end
3238
function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff},
3339
sd::AbstractMaybeSparsityDetection, f!::F, fx, x) where {F}
3440
coloring_result = sd(ad, f!, fx, x)
41+
tag = __standard_tag(ad.tag, x)
3542
if coloring_result isa NoMatrixColoring
36-
cache = ForwardDiff.JacobianConfig(f!, fx, x, __chunksize(ad, x),
37-
__standard_tag(ad.tag, x))
43+
cache = ForwardDiff.JacobianConfig(f!, fx, x, __chunksize(ad, x), tag)
3844
jac_prototype = nothing
3945
else
46+
# Colored ForwardDiff passes `tag` directly into Dual so we need the `typeof`
4047
cache = ForwardColorJacCache(f!, x, __chunksize(ad); coloring_result.colorvec,
41-
dx = fx, sparsity = coloring_result.jacobian_sparsity, ad.tag)
48+
dx = fx, sparsity = coloring_result.jacobian_sparsity, tag = typeof(tag))
4249
jac_prototype = coloring_result.jacobian_sparsity
4350
end
4451
return ForwardDiffJacobianCache(coloring_result, cache, jac_prototype, fx, x)
@@ -49,8 +56,7 @@ function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache,
4956
if cache.cache isa ForwardColorJacCache
5057
forwarddiff_color_jacobian(J, f, x, cache.cache) # Use Sparse ForwardDiff
5158
else
52-
# Disable tag checking since we set the tag to our custom tag
53-
ForwardDiff.jacobian!(J, f, x, cache.cache, Val(false)) # Don't try to exploit sparsity
59+
ForwardDiff.jacobian!(J, f, x, cache.cache) # Don't try to exploit sparsity
5460
end
5561
return J
5662
end
@@ -60,8 +66,7 @@ function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache,
6066
if cache.cache isa ForwardColorJacCache
6167
forwarddiff_color_jacobian!(J, f!, x, cache.cache) # Use Sparse ForwardDiff
6268
else
63-
# Disable tag checking since we set the tag to our custom tag
64-
ForwardDiff.jacobian!(J, f!, fx, x, cache.cache, Val(false)) # Don't try to exploit sparsity
69+
ForwardDiff.jacobian!(J, f!, fx, x, cache.cache) # Don't try to exploit sparsity
6570
end
6671
return J
6772
end

0 commit comments

Comments
 (0)