Skip to content

Commit 6379a09

Browse files
fix: run late_binding_update_u0_p in reinit!
1 parent f3d4218 commit 6379a09

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,8 @@ function terminate!(integrator::ODEIntegrator, retcode = ReturnCode.Terminated)
321321
integrator.opts.tstops.valtree = typeof(integrator.opts.tstops.valtree)()
322322
end
323323

324+
const EMPTY_ARRAY_OF_PAIRS = Pair[]
325+
324326
DiffEqBase.has_reinit(integrator::ODEIntegrator) = true
325327
function DiffEqBase.reinit!(integrator::ODEIntegrator, u0 = integrator.sol.prob.u0;
326328
t0 = integrator.sol.prob.tspan[1],
@@ -335,6 +337,23 @@ function DiffEqBase.reinit!(integrator::ODEIntegrator, u0 = integrator.sol.prob.
335337
reinit_callbacks = true, initialize_save = true,
336338
reinit_cache = true,
337339
reinit_retcode = true)
340+
if reinit_dae && SciMLBase.has_initializeprob(integrator.sol.prob.f)
341+
# This is `remake` infrastructure. `reinit!` is somewhat like `remake` for
342+
# integrators, so we reuse some of the same pieces. If we pass `integrator.p`
343+
# for `p`, it means we don't want to change it. If we pass `missing`, this
344+
# function may (correctly) assume `newp` aliases `prob.p` and copy it, which we
345+
# want to avoid. So we pass an empty array of pairs to make it think this is
346+
# a symbolic `remake` and it can modify `newp` inplace. The array of pairs is a
347+
# const global to avoid allocating every time this function is called.
348+
u0, newp = SciMLBase.late_binding_update_u0_p(integrator.sol.prob, u0,
349+
EMPTY_ARRAY_OF_PAIRS, t0, u0, integrator.p)
350+
if newp !== integrator.p
351+
integrator.p = newp
352+
sol = integrator.sol
353+
@reset sol.prob.p = newp
354+
integrator.sol = sol
355+
end
356+
end
338357
if isinplace(integrator.sol.prob)
339358
recursivecopy!(integrator.u, u0)
340359
recursivecopy!(integrator.uprev, integrator.u)

test/interface/dae_initialize_integration.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,20 @@ sol = solve(prob, Rodas5P(), dt = 1e-10)
7676
@test sol[1] == [1.0]
7777
@test sol[2] [0.9999999998]
7878
@test sol[end] [-1.0]
79+
80+
@testset "`reinit!` updates initial parameters" begin
81+
# https://github.com/SciML/ModelingToolkit.jl/issues/3451
82+
# https://github.com/SciML/ModelingToolkit.jl/issues/3504
83+
@variables x(t) y(t)
84+
@parameters c1 c2
85+
@mtkbuild sys = ODESystem([D(x) ~ -c1 * x + c2 * y, D(y) ~ c1 * x - c2 * y], t)
86+
prob = ODEProblem(sys, [1.0, 2.0], (0.0, 1.0), [c1 => 1.0, c2 => 2.0])
87+
@test prob.ps[Initial(x)] 1.0
88+
@test prob.ps[Initial(y)] 2.0
89+
integ = init(prob, Tsit5())
90+
@test integ.ps[Initial(x)] 1.0
91+
@test integ.ps[Initial(y)] 2.0
92+
reinit!(integ, [2.0, 3.0])
93+
@test integ.ps[Initial(x)] 2.0
94+
@test integ.ps[Initial(y)] 3.0
95+
end

0 commit comments

Comments
 (0)