Skip to content

Commit 1b537a0

Browse files
Merge pull request #2540 from oscardssmith/os/fix-DefaultCache-type-stability
fix type stability for `DefaultCache`
2 parents 9799ee1 + 64defc0 commit 1b537a0

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ function alg_cache(alg::CompositeAlgorithm{CS, Tuple{A1, A2, A3, A4, A5, A6}}, u
7373
args = (u, rate_prototype, uEltypeNoUnits,
7474
uBottomEltypeNoUnits, tTypeNoUnits, uprev, uprev2, f, t, dt,
7575
reltol, p, calck, Val(V))
76-
argT = map(typeof, args)
76+
# Core.Typeof to turn uEltypeNoUnits into Type{uEltypeNoUnits} rather than DataType
77+
argT = map(Core.Typeof, args)
7778
T1 = Base.promote_op(alg_cache, A1, argT...)
7879
T2 = Base.promote_op(alg_cache, A2, argT...)
7980
T3 = Base.promote_op(alg_cache, A3, argT...)

lib/OrdinaryDiffEqCore/src/dense/generic_dense.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,12 @@ function default_ode_interpolant(
168168
return ode_interpolant(Θ, integrator.dt, integrator.uprev,
169169
integrator.u, integrator.k, cache.cache5, idxs,
170170
deriv, integrator.differential_vars)
171-
else # alg_choice == 6
171+
elseif alg_choice == 6
172172
return ode_interpolant(Θ, integrator.dt, integrator.uprev,
173173
integrator.u, integrator.k, cache.cache6, idxs,
174174
deriv, integrator.differential_vars)
175+
else
176+
error("DefaultCache invalid alg_choice. File an issue.")
175177
end
176178
end
177179

@@ -227,6 +229,8 @@ end
227229
ode_interpolant!(val, Θ, integrator.dt, integrator.uprev, integrator.u,
228230
integrator.k, integrator.cache.cache6,
229231
idxs, deriv, integrator.differential_vars)
232+
else
233+
error("DefaultCache invalid alg_choice. File an issue.")
230234
end
231235
else
232236
ode_interpolant!(val, Θ, integrator.dt, integrator.uprev, integrator.u,
@@ -256,10 +260,12 @@ function default_ode_interpolant!(
256260
return ode_interpolant!(val, Θ, integrator.dt, integrator.uprev,
257261
integrator.u, integrator.k, cache.cache5, idxs,
258262
deriv, integrator.differential_vars)
259-
else # alg_choice == 6
263+
elseif alg_choice == 6
260264
return ode_interpolant!(val, Θ, integrator.dt, integrator.uprev,
261265
integrator.u, integrator.k, cache.cache6, idxs,
262266
deriv, integrator.differential_vars)
267+
else
268+
error("DefaultCache invalid alg_choice. File an issue.")
263269
end
264270
end
265271

@@ -380,6 +386,8 @@ function default_ode_extrapolant!(
380386
ode_interpolant!(val, Θ, integrator.t - integrator.tprev,
381387
integrator.uprev2, integrator.uprev,
382388
integrator.k, cache.cache6, idxs, deriv, integrator.differential_vars)
389+
else
390+
error("DefaultCache invalid alg_choice. File an issue.")
383391
end
384392
end
385393

@@ -444,6 +452,8 @@ function default_ode_extrapolant(
444452
ode_interpolant(Θ, integrator.t - integrator.tprev,
445453
integrator.uprev2, integrator.uprev,
446454
integrator.k, cache.cache6, idxs, deriv, integrator.differential_vars)
455+
else
456+
error("DefaultCache invalid alg_choice. File an issue.")
447457
end
448458
end
449459

@@ -810,6 +820,8 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p,
810820
cache.cache6) # update the kcurrent
811821
val = ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊],
812822
cache.cache6, idxs, deriv, differential_vars)
823+
else
824+
error("DefaultCache invalid alg_choice. File an issue.")
813825
end
814826
else
815827
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
@@ -892,6 +904,8 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p,
892904
cache.cache6) # update the kcurrent
893905
ode_interpolant!(out, Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊],
894906
cache.cache6, idxs, deriv, differential_vars)
907+
else
908+
error("DefaultCache invalid alg_choice. File an issue.")
895909
end
896910
else
897911
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,

lib/OrdinaryDiffEqDefault/test/default_solver_tests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ end
4040
prob_rober = ODEProblem(rober, [1.0, 0.0, 0.0], (0.0, 1e3), (0.04, 3e7, 1e4))
4141
sol = solve(prob_rober)
4242
rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23(autodiff = false)))
43+
#test that cache is type stable
44+
@test typeof(sol.interp.cache.cache3) == typeof(rosensol.interp.cache.caches[2])
4345
# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this).
4446
@test sol.stats.naccept == rosensol.stats.naccept
4547
@test sol.stats.nf == rosensol.stats.nf
@@ -50,6 +52,8 @@ rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23(autodiff = false)))
5052
sol = solve(prob_rober, reltol = 1e-7, abstol = 1e-7)
5153
rosensol = solve(
5254
prob_rober, AutoVern7(Rodas5P(autodiff = false)), reltol = 1e-7, abstol = 1e-7)
55+
#test that cache is type stable
56+
@test typeof(sol.interp.cache.cache4) == typeof(rosensol.interp.cache.caches[2])
5357
# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this).
5458
@test sol.stats.naccept == rosensol.stats.naccept
5559
@test sol.stats.nf == rosensol.stats.nf

0 commit comments

Comments
 (0)