Skip to content

Commit 63fc800

Browse files
committed
Make things kind-of type stable when chunksize is not specified
1 parent c8c61bf commit 63fc800

File tree

4 files changed

+21
-7
lines changed

4 files changed

+21
-7
lines changed

src/highlevel/common.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ function sparse_jacobian(ad::AbstractADType, sd::AbstractMaybeSparsityDetection,
186186
kwargs...)
187187
cache = sparse_jacobian_cache(ad, sd, args...; kwargs...)
188188
J = init_jacobian(cache)
189-
return sparse_jacobian!(J, ad, cache, args...)
189+
sparse_jacobian!(J, ad, cache, args...)
190+
return J
190191
end
191192

192193
"""
@@ -199,7 +200,8 @@ Jacobian at every function call
199200
function sparse_jacobian(ad::AbstractADType, cache::AbstractMaybeSparseJacobianCache,
200201
args...)
201202
J = init_jacobian(cache)
202-
return sparse_jacobian!(J, ad, cache, args...)
203+
sparse_jacobian!(J, ad, cache, args...)
204+
return J
203205
end
204206

205207
"""
@@ -216,7 +218,8 @@ with the same cache to compute the jacobian.
216218
function sparse_jacobian!(J::AbstractMatrix, ad::AbstractADType,
217219
sd::AbstractMaybeSparsityDetection, args...; kwargs...)
218220
cache = sparse_jacobian_cache(ad, sd, args...; kwargs...)
219-
return sparse_jacobian!(J, ad, cache, args...)
221+
sparse_jacobian!(J, ad, cache, args...)
222+
return J
220223
end
221224

222225
## Internal

src/highlevel/finite_diff.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ struct FiniteDiffJacobianCache{CO, CA, J, FX, X} <: AbstractMaybeSparseJacobianC
66
x::X
77
end
88

9+
__getfield(c::FiniteDiffJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype
10+
911
function sparse_jacobian_cache(fd::Union{AutoSparseFiniteDiff, AutoFiniteDiff},
1012
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
1113
coloring_result = sd(fd, f, x)

src/highlevel/forward_mode.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,20 @@ struct ForwardDiffJacobianCache{CO, CA, J, FX, X} <: AbstractMaybeSparseJacobian
66
x::X
77
end
88

9+
__getfield(c::ForwardDiffJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype
10+
911
struct SparseDiffToolsTag end
1012

13+
__standard_tag(::Nothing, x) = ForwardDiff.Tag(SparseDiffToolsTag(), eltype(x))
14+
__standard_tag(tag, _) = tag
15+
1116
function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff},
1217
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
1318
coloring_result = sd(ad, f, x)
1419
fx = fx === nothing ? similar(f(x)) : fx
1520
if coloring_result isa NoMatrixColoring
1621
cache = ForwardDiff.JacobianConfig(f, x, __chunksize(ad, x),
17-
ifelse(ad.tag === nothing, SparseDiffToolsTag(), ad.tag))
22+
__standard_tag(ad.tag, x))
1823
jac_prototype = nothing
1924
else
2025
cache = ForwardColorJacCache(f, x, __chunksize(ad); coloring_result.colorvec,
@@ -29,7 +34,7 @@ function sparse_jacobian_cache(ad::Union{AutoSparseForwardDiff, AutoForwardDiff}
2934
coloring_result = sd(ad, f!, fx, x)
3035
if coloring_result isa NoMatrixColoring
3136
cache = ForwardDiff.JacobianConfig(f!, fx, x, __chunksize(ad, x),
32-
ifelse(ad.tag === nothing, SparseDiffToolsTag(), ad.tag))
37+
__standard_tag(ad.tag, x))
3338
jac_prototype = nothing
3439
else
3540
cache = ForwardColorJacCache(f!, x, __chunksize(ad); coloring_result.colorvec,
@@ -44,7 +49,8 @@ function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache,
4449
if cache.cache isa ForwardColorJacCache
4550
forwarddiff_color_jacobian(J, f, x, cache.cache) # Use Sparse ForwardDiff
4651
else
47-
ForwardDiff.jacobian!(J, f, x, cache.cache) # Don't try to exploit sparsity
52+
# Disable tag checking since we set the tag to our custom tag
53+
ForwardDiff.jacobian!(J, f, x, cache.cache, Val(false)) # Don't try to exploit sparsity
4854
end
4955
return J
5056
end
@@ -54,7 +60,8 @@ function sparse_jacobian!(J::AbstractMatrix, _, cache::ForwardDiffJacobianCache,
5460
if cache.cache isa ForwardColorJacCache
5561
forwarddiff_color_jacobian!(J, f!, x, cache.cache) # Use Sparse ForwardDiff
5662
else
57-
ForwardDiff.jacobian!(J, f!, fx, x, cache.cache) # Don't try to exploit sparsity
63+
# Disable tag checking since we set the tag to our custom tag
64+
ForwardDiff.jacobian!(J, f!, fx, x, cache.cache, Val(false)) # Don't try to exploit sparsity
5865
end
5966
return J
6067
end

src/highlevel/reverse_mode.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ struct ReverseModeJacobianCache{CO, CA, J, FX, X, I} <: AbstractMaybeSparseJacob
77
idx_vec::I
88
end
99

10+
__getfield(c::ReverseModeJacobianCache, ::Val{:jac_prototype}) = c.jac_prototype
11+
1012
function sparse_jacobian_cache(ad::Union{AutoEnzyme, AbstractReverseMode},
1113
sd::AbstractMaybeSparsityDetection, f::F, x; fx = nothing) where {F}
1214
fx = fx === nothing ? similar(f(x)) : fx

0 commit comments

Comments
 (0)