Skip to content

Commit 84c0dc3

Browse files
committed
Fix cache sizes
1 parent 0a5398c commit 84c0dc3

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

src/differentiation/jaches_products.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,26 +264,27 @@ f(du, u) # Otherwise
264264
"""
265265
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
266266
autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...)
267+
ff = JacFunctionWrapper(f, fu, u, p, t)
268+
fu === nothing && (fu = __internal_oop(ff) ? ff(u) : u)
269+
267270
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
268-
cache1 = similar(u)
271+
cache1 = similar(fu)
269272
cache2 = similar(u)
270273

271274
(cache1, cache2), num_jacvec, num_jacvec!
272275
elseif autodiff isa AutoForwardDiff
273276
cache1 = Dual{
274277
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1,
275278
}.(u, ForwardDiff.Partials.(tuple.(u)))
276-
277-
cache2 = copy(cache1)
279+
cache2 = Dual{
280+
typeof(ForwardDiff.Tag(tag, eltype(fu))), eltype(fu), 1,
281+
}.(fu, ForwardDiff.Partials.(tuple.(fu)))
278282

279283
(cache1, cache2), auto_jacvec, auto_jacvec!
280284
else
281285
error("Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()")
282286
end
283287

284-
ff = JacFunctionWrapper(f, fu, u, p, t)
285-
fu === nothing && (fu = __internal_oop(ff) ? ff(u) : u)
286-
287288
op = FwdModeAutoDiffVecProd(ff, u, cache, vecprod, vecprod!)
288289

289290
return FunctionOperator(op, u, fu; isinplace = Val(true), outofplace = Val(true), p, t,

0 commit comments

Comments
 (0)