Skip to content

Commit e3846c8

Browse files
Merge pull request #873 from AayushSabharwal/as/remake-initializeprob
refactor: change `remake_initialization_data`
2 parents c15f8a4 + 4d44eff commit e3846c8

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

src/remake.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ function remake(prob::ODEProblem; f = missing,
125125

126126
if f === missing
127127
if build_initializeprob
128-
initialization_data = remake_initialization_data(
129-
prob.f.sys, prob.f, u0, tspan[1], p)
128+
initialization_data = remake_initialization_data_compat_wrapper(
129+
prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp)
130130
else
131131
initialization_data = nothing
132132
end
@@ -203,16 +203,32 @@ function remake_initializeprob(sys, scimlfn, u0, t0, p)
203203
end
204204

205205
"""
206-
remake_initialization_data(sys, scimlfn, u0, t0, p)
206+
$(TYPEDSIGNATURES)
207+
208+
Wrapper around `remake_initialization_data` for backward compatibility when `newu0` and
209+
`newp` were not arguments.
210+
"""
211+
function remake_initialization_data_compat_wrapper(sys, scimlfn, u0, t0, p, newu0, newp)
212+
if hasmethod(remake_initialization_data,
213+
Tuple{typeof(sys), typeof(scimlfn), typeof(u0), typeof(t0), typeof(p)})
214+
remake_initialization_data(sys, scimlfn, u0, t0, p)
215+
else
216+
remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp)
217+
end
218+
end
219+
220+
"""
221+
remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp)
207222
208223
Re-create the initialization data present in the function `scimlfn`, using the
209-
associated system `sys` and the user provided new values of `u0`, initial time `t0` and
210-
`p`. By default, this calls `remake_initializeprob` for backward compatibility and
211-
attempts to construct an `OverrideInitData` from the result.
224+
associated system `sys`, the user provided new values of `u0`, initial time `t0`,
225+
user-provided `p`, new u0 vector `newu0` and new parameter object `newp`. By default,
226+
this calls `remake_initializeprob` for backward compatibility and attempts to construct
227+
an `OverrideInitData` from the result.
212228
213229
Note that `u0` or `p` may be `missing` if the user does not provide a value for them.
214230
"""
215-
function remake_initialization_data(sys, scimlfn, u0, t0, p)
231+
function remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp)
216232
return reconstruct_initialization_data(
217233
nothing, remake_initializeprob(sys, scimlfn, u0, t0, p)...)
218234
end

test/downstream/adjoints.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ gs_ts, = Zygote.gradient(sol) do sol
6868
sum(sum.(sol[[lorenz1.x, lorenz2.x], :]))
6969
end
7070

71-
@test_broken all(map(x -> x == true_grad_vecsym, gs_ts))
71+
@test all(map(x -> x == true_grad_vecsym, gs_ts))
7272

7373
# BatchedInterface AD
7474
@variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0

0 commit comments

Comments
 (0)