Skip to content

Commit 20e46e7

Browse files
Improve NULL FSAL and fix du for non-FSAL
Fixes the downstream SciMLSensitivity
1 parent 951df5d commit 20e46e7

File tree

9 files changed

+37
-10
lines changed

9 files changed

+37
-10
lines changed

lib/OrdinaryDiffEqBDF/src/dae_caches.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ end
1515
end
1616

1717
# Not FSAL
18-
get_fsalfirstlast(cache::DImplicitEulerCache, u) = (u, u)
18+
get_fsalfirstlast(cache::DImplicitEulerCache, u) = (nothing, nothing)
1919

2020
mutable struct DImplicitEulerConstantCache{N} <: OrdinaryDiffEqConstantCache
2121
nlsolver::N

lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ end
5959
@inline function DiffEqBase.get_du(integrator::ODEIntegrator)
6060
isdiscretecache(integrator.cache) &&
6161
error("Derivatives are not defined for this stepper.")
62-
return if isdefined(integrator, :fsallast)
62+
return if isfsal(integrator.alg)
6363
integrator.fsallast
6464
else
6565
integrator(integrator.t, Val{1})
@@ -72,7 +72,7 @@ end
7272
if isdiscretecache(integrator.cache)
7373
out .= integrator.cache.tmp
7474
else
75-
return if isdefined(integrator, :fsallast) &&
75+
return if isfsal(integrator.alg) &&
7676
!has_stiff_interpolation(integrator.alg)
7777
# Special stiff interpolations do not store the
7878
# right value in fsallast

lib/OrdinaryDiffEqFunctionMap/src/functionmap_caches.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
uprev::uType
44
tmp::rateType
55
end
6-
get_fsalfirstlast(cache::FunctionMapCache, u) = (cache.u, cache.uprev)
6+
get_fsalfirstlast(cache::FunctionMapCache, u) = (nothing, nothing)
77

88
function alg_cache(alg::FunctionMap, u, rate_prototype, ::Type{uEltypeNoUnits},
99
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,

lib/OrdinaryDiffEqLowStorageRK/src/low_storage_rk_perform_step.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333
integrator.u = u
3434
end
3535

36-
get_fsalfirstlast(cache::LowStorageRK2NCache, u) = (cache.k, cache.k)
36+
get_fsalfirstlast(cache::LowStorageRK2NCache, u) = (nothing, nothing)
3737

3838
function initialize!(integrator, cache::LowStorageRK2NCache)
3939
@unpack k, tmp, williamson_condition = cache

lib/OrdinaryDiffEqPDIRK/src/pdirk_caches.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
end
99

1010
# Non-FSAL
11-
get_fsalfirstlast(cache::PDIRK44Cache, u) = (cache.u, cache.uprev)
11+
get_fsalfirstlast(cache::PDIRK44Cache, u) = (nothing, nothing)
1212

1313
struct PDIRK44ConstantCache{N, TabType} <: OrdinaryDiffEqConstantCache
1414
nlsolver::N

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ abstract type RosenbrockMutableCache <: OrdinaryDiffEqMutableCache end
22
abstract type RosenbrockConstantCache <: OrdinaryDiffEqConstantCache end
33

44
# Fake values since non-FSAL
5-
get_fsalfirstlast(cache::RosenbrockMutableCache, u) = (zero(u), zero(u))
5+
get_fsalfirstlast(cache::RosenbrockMutableCache, u) = (nothing, nothing)
66

77
################################################################################
88

lib/OrdinaryDiffEqVerner/src/verner_caches.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ end
8787
end
8888

8989
# fake values since non-FSAL method
90-
get_fsalfirstlast(cache::Vern7Cache, u) = (cache.k1, cache.k2)
90+
get_fsalfirstlast(cache::Vern7Cache, u) = (nothing, nothing)
9191

9292
function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits},
9393
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
@@ -153,7 +153,7 @@ end
153153
end
154154

155155
# fake values since non-FSAL method
156-
get_fsalfirstlast(cache::Vern8Cache, u) = (cache.k1, cache.k2)
156+
get_fsalfirstlast(cache::Vern8Cache, u) = (nothing, nothing)
157157

158158
function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits},
159159
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
@@ -227,7 +227,7 @@ end
227227
end
228228

229229
# fake values since non-FSAL method
230-
get_fsalfirstlast(cache::Vern9Cache, u) = (cache.k1, cache.k2)
230+
get_fsalfirstlast(cache::Vern9Cache, u) = (nothing, nothing)
231231

232232
function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits},
233233
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,

test/interface/get_du.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using OrdinaryDiffEq, OrdinaryDiffEqCore, Test
2+
function lorenz!(du, u, p, t)
3+
du[1] = 10.0(u[2] - u[1])
4+
du[2] = u[1] * (28.0 - u[3]) - u[2]
5+
du[3] = u[1] * u[2] - (8 / 3) * u[3]
6+
end
7+
u0 = [1.0; 0.0; 0.0]
8+
tspan = (0.0, 3.0)
9+
condition(u,t,integrator) = t == 0.2
10+
cache = zeros(3)
11+
affect!(integrator) = cache .= get_du(integrator)
12+
dusave = DiscreteCallback(condition, affect!)
13+
affect2!(integrator) = get_du!(cache, integrator)
14+
dusave_inplace = DiscreteCallback(condition, affect2!)
15+
16+
prob = ODEProblem(lorenz!, u0, tspan)
17+
sol = solve(prob, Tsit5(), tstops = [0.2], callback = dusave, abstol=1e-12, reltol=1e-12)
18+
res = copy(cache)
19+
20+
for alg in [Vern6(), Vern7(), Vern8(), Vern9(), Rodas4(), Rodas4P(), Rodas5(), Rodas5P(), TRBDF2(), KenCarp4(), FBDF(), QNDF()]
21+
sol = solve(prob, alg, tstops = [0.2], callback = dusave, abstol=1e-12, reltol=1e-12)
22+
@test res cache rtol=1e-5
23+
24+
sol = solve(prob, alg, tstops = [0.2], callback = dusave_inplace, abstol=1e-12, reltol=1e-12)
25+
@test res cache rtol=1e-5
26+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ end
6666
@time @safetestset "Linear Solver Split ODE Tests" include("interface/linear_solver_split_ode_test.jl")
6767
@time @safetestset "Sparse Diff Tests" include("interface/sparsediff_tests.jl")
6868
@time @safetestset "Enum Tests" include("interface/enums.jl")
69+
@time @safetestset "Enum Tests" include("interface/get_du.jl")
6970
@time @safetestset "Mass Matrix Tests" include("interface/mass_matrix_tests.jl")
7071
@time @safetestset "W-Operator prototype tests" include("interface/wprototype_tests.jl")
7172
end

0 commit comments

Comments
 (0)