Skip to content

Commit f0b8470

Browse files
Merge pull request #2445 from SciML/fsal_du
Improve NULL FSAL and fix du for non-FSAL
2 parents 951df5d + 624a9f9 commit f0b8470

File tree

13 files changed

+93
-59
lines changed

13 files changed

+93
-59
lines changed

lib/OrdinaryDiffEqBDF/src/dae_caches.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ end
1515
end
1616

1717
# Not FSAL
18-
get_fsalfirstlast(cache::DImplicitEulerCache, u) = (u, u)
18+
get_fsalfirstlast(cache::DImplicitEulerCache, u) = (nothing, nothing)
1919

2020
mutable struct DImplicitEulerConstantCache{N} <: OrdinaryDiffEqConstantCache
2121
nlsolver::N

lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@ mutable struct CompositeCache{T, F} <: OrdinaryDiffEqCache
1313
current::Int
1414
end
1515

16-
get_fsalfirstlast(cache::CompositeCache, u) = get_fsalfirstlast(cache.caches[1], u)
16+
function get_fsalfirstlast(cache::CompositeCache, u)
17+
_x = get_fsalfirstlast(cache.caches[1], u)
18+
if first(_x) !== nothing
19+
return _x
20+
else
21+
return get_fsalfirstlast(cache.caches[2], u)
22+
end
23+
end
1724

1825
mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F, uType} <: OrdinaryDiffEqCache
1926
args::A

lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ end
5959
@inline function DiffEqBase.get_du(integrator::ODEIntegrator)
6060
isdiscretecache(integrator.cache) &&
6161
error("Derivatives are not defined for this stepper.")
62-
return if isdefined(integrator, :fsallast)
62+
return if isfsal(integrator.alg)
6363
integrator.fsallast
6464
else
6565
integrator(integrator.t, Val{1})
@@ -72,7 +72,7 @@ end
7272
if isdiscretecache(integrator.cache)
7373
out .= integrator.cache.tmp
7474
else
75-
return if isdefined(integrator, :fsallast) &&
75+
return if isfsal(integrator.alg) &&
7676
!has_stiff_interpolation(integrator.alg)
7777
# Special stiff interpolations do not store the
7878
# right value in fsallast
@@ -221,8 +221,8 @@ function resize!(integrator::ODEIntegrator, i::Int)
221221
# may be required for things like units
222222
c !== nothing && resize!(c, i)
223223
end
224-
resize!(integrator.fsalfirst, i)
225-
resize!(integrator.fsallast, i)
224+
!isnothing(integrator.fsalfirst) && resize!(integrator.fsalfirst, i)
225+
!isnothing(integrator.fsallast) && resize!(integrator.fsallast, i)
226226
resize_f!(integrator.f, i)
227227
resize_nlsolver!(integrator, i)
228228
resize_J_W!(cache, integrator, i)
@@ -235,8 +235,8 @@ function resize!(integrator::ODEIntegrator, i::NTuple{N, Int}) where {N}
235235
for c in full_cache(cache)
236236
resize!(c, i)
237237
end
238-
resize!(integrator.fsalfirst, i)
239-
resize!(integrator.fsallast, i)
238+
!isnothing(integrator.fsalfirst) && resize!(integrator.fsalfirst, i)
239+
!isnothing(integrator.fsallast) && resize!(integrator.fsallast, i)
240240
resize_f!(integrator.f, i)
241241
# TODO the parts below need to be adapted for implicit methods
242242
isdefined(integrator.cache, :nlsolver) && resize_nlsolver!(integrator, i)

lib/OrdinaryDiffEqCore/src/perform_step/composite_perform_step.jl

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -33,41 +33,41 @@ function initialize!(integrator, cache::DefaultCache)
3333
u = integrator.u
3434
if cache.current == 1
3535
fsalfirst, fsallast = get_fsalfirstlast(cache.cache1, u)
36-
integrator.fsalfirst = fsalfirst
37-
integrator.fsallast = fsallast
36+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
37+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
3838
initialize!(integrator, cache.cache1)
3939
elseif cache.current == 2
4040
fsalfirst, fsallast = get_fsalfirstlast(cache.cache2, u)
41-
integrator.fsalfirst = fsalfirst
42-
integrator.fsallast = fsallast
41+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
42+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
4343
initialize!(integrator, cache.cache2)
4444
# the controller was initialized by default for algs[1]
4545
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[2])
4646
elseif cache.current == 3
4747
fsalfirst, fsallast = get_fsalfirstlast(cache.cache3, u)
48-
integrator.fsalfirst = fsalfirst
49-
integrator.fsallast = fsallast
48+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
49+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
5050
initialize!(integrator, cache.cache3)
5151
# the controller was initialized by default for algs[1]
5252
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[3])
5353
elseif cache.current == 4
5454
fsalfirst, fsallast = get_fsalfirstlast(cache.cache4, u)
55-
integrator.fsalfirst = fsalfirst
56-
integrator.fsallast = fsallast
55+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
56+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
5757
initialize!(integrator, cache.cache4)
5858
# the controller was initialized by default for algs[1]
5959
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[4])
6060
elseif cache.current == 5
6161
fsalfirst, fsallast = get_fsalfirstlast(cache.cache5, u)
62-
integrator.fsalfirst = fsalfirst
63-
integrator.fsallast = fsallast
62+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
63+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
6464
initialize!(integrator, cache.cache5)
6565
# the controller was initialized by default for algs[1]
6666
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[5])
6767
elseif cache.current == 6
6868
fsalfirst, fsallast = get_fsalfirstlast(cache.cache6, u)
69-
integrator.fsalfirst = fsalfirst
70-
integrator.fsallast = fsallast
69+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
70+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
7171
initialize!(integrator, cache.cache6)
7272
# the controller was initialized by default for algs[1]
7373
reset_alg_dependent_opts!(integrator.opts.controller, algs[1], algs[6])
@@ -80,21 +80,21 @@ function initialize!(integrator, cache::CompositeCache)
8080
u = integrator.u
8181
if cache.current == 1
8282
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1], u)
83-
integrator.fsalfirst = fsalfirst
84-
integrator.fsallast = fsallast
83+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
84+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
8585
initialize!(integrator, @inbounds(cache.caches[1]))
8686
elseif cache.current == 2
8787
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2], u)
88-
integrator.fsalfirst = fsalfirst
89-
integrator.fsallast = fsallast
88+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
89+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
9090
initialize!(integrator, @inbounds(cache.caches[2]))
9191
# the controller was initialized by default for integrator.alg.algs[1]
9292
reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1],
9393
integrator.alg.algs[2])
9494
else
9595
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[cache.current], u)
96-
integrator.fsalfirst = fsalfirst
97-
integrator.fsallast = fsallast
96+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
97+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
9898
initialize!(integrator, @inbounds(cache.caches[cache.current]))
9999
reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1],
100100
integrator.alg.algs[cache.current])
@@ -107,13 +107,13 @@ function initialize!(integrator, cache::CompositeCache{Tuple{T1, T2}, F}) where
107107
u = integrator.u
108108
if cache.current == 1
109109
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1], u)
110-
integrator.fsalfirst = fsalfirst
111-
integrator.fsallast = fsallast
110+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
111+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
112112
initialize!(integrator, @inbounds(cache.caches[1]))
113113
elseif cache.current == 2
114114
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2], u)
115-
integrator.fsalfirst = fsalfirst
116-
integrator.fsallast = fsallast
115+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
116+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
117117
initialize!(integrator, @inbounds(cache.caches[2]))
118118
reset_alg_dependent_opts!(integrator.opts.controller, integrator.alg.algs[1],
119119
integrator.alg.algs[2])
@@ -173,13 +173,13 @@ function choose_algorithm!(integrator,
173173
cache.current = new_current
174174
if new_current == 1
175175
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[1], u)
176-
integrator.fsalfirst = fsalfirst
177-
integrator.fsallast = fsallast
176+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
177+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
178178
initialize!(integrator, @inbounds(cache.caches[1]))
179179
elseif new_current == 2
180180
fsalfirst, fsallast = get_fsalfirstlast(cache.caches[2], u)
181-
integrator.fsalfirst = fsalfirst
182-
integrator.fsallast = fsallast
181+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
182+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
183183
initialize!(integrator, @inbounds(cache.caches[2]))
184184
end
185185
if old_current == 1 && new_current == 2
@@ -206,38 +206,38 @@ function choose_algorithm!(integrator, cache::DefaultCache)
206206
init_ith_default_cache(cache, algs, new_current)
207207
if new_current == 1
208208
fsalfirst, fsallast = get_fsalfirstlast(cache.cache1, u)
209-
integrator.fsalfirst = fsalfirst
210-
integrator.fsallast = fsallast
209+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
210+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
211211
initialize!(integrator, @inbounds(cache.cache1))
212212
new_cache = cache.cache1
213213
elseif new_current == 2
214214
fsalfirst, fsallast = get_fsalfirstlast(cache.cache2, u)
215-
integrator.fsalfirst = fsalfirst
216-
integrator.fsallast = fsallast
215+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
216+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
217217
initialize!(integrator, @inbounds(cache.cache2))
218218
new_cache = cache.cache2
219219
elseif new_current == 3
220220
fsalfirst, fsallast = get_fsalfirstlast(cache.cache3, u)
221-
integrator.fsalfirst = fsalfirst
222-
integrator.fsallast = fsallast
221+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
222+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
223223
initialize!(integrator, @inbounds(cache.cache3))
224224
new_cache = cache.cache3
225225
elseif new_current == 4
226226
fsalfirst, fsallast = get_fsalfirstlast(cache.cache4, u)
227-
integrator.fsalfirst = fsalfirst
228-
integrator.fsallast = fsallast
227+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
228+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
229229
initialize!(integrator, @inbounds(cache.cache4))
230230
new_cache = cache.cache4
231231
elseif new_current == 5
232232
fsalfirst, fsallast = get_fsalfirstlast(cache.cache5, u)
233-
integrator.fsalfirst = fsalfirst
234-
integrator.fsallast = fsallast
233+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
234+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
235235
initialize!(integrator, @inbounds(cache.cache5))
236236
new_cache = cache.cache5
237237
elseif new_current == 6
238238
fsalfirst, fsallast = get_fsalfirstlast(cache.cache6, u)
239-
integrator.fsalfirst = fsalfirst
240-
integrator.fsallast = fsallast
239+
!isnothing(fsalfirst) && (integrator.fsalfirst = fsalfirst)
240+
!isnothing(fsallast) && (integrator.fsallast = fsallast)
241241
initialize!(integrator, @inbounds(cache.cache6))
242242
new_cache = cache.cache6
243243
end

lib/OrdinaryDiffEqCore/src/solve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -469,14 +469,14 @@ function DiffEqBase.__init(
469469
reinitiailize = true
470470
saveiter = 0 # Starts at 0 so first save is at 1
471471
saveiter_dense = 0
472-
faslfirst, fsallast = get_fsalfirstlast(cache, rate_prototype)
472+
fsalfirst, fsallast = get_fsalfirstlast(cache, rate_prototype)
473473

474474
integrator = ODEIntegrator{typeof(_alg), isinplace(prob), uType, typeof(du),
475475
tType, typeof(p),
476476
typeof(eigen_est), typeof(EEst),
477477
QT, typeof(tdir), typeof(k), SolType,
478478
FType, cacheType,
479-
typeof(opts), typeof(faslfirst),
479+
typeof(opts), typeof(fsalfirst),
480480
typeof(last_event_error), typeof(callback_cache),
481481
typeof(initializealg), typeof(differential_vars)}(
482482
sol, u, du, k, t, tType(dt), f, p,
@@ -496,7 +496,7 @@ function DiffEqBase.__init(
496496
isout, reeval_fsal,
497497
u_modified, reinitiailize, isdae,
498498
opts, stats, initializealg, differential_vars,
499-
faslfirst, fsallast)
499+
fsalfirst, fsallast)
500500

501501
if initialize_integrator
502502
if isdae || SciMLBase.has_initializeprob(prob.f)

lib/OrdinaryDiffEqFunctionMap/src/functionmap_caches.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
uprev::uType
44
tmp::rateType
55
end
6-
get_fsalfirstlast(cache::FunctionMapCache, u) = (cache.u, cache.uprev)
6+
get_fsalfirstlast(cache::FunctionMapCache, u) = (nothing, nothing)
77

88
function alg_cache(alg::FunctionMap, u, rate_prototype, ::Type{uEltypeNoUnits},
99
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,

lib/OrdinaryDiffEqLowStorageRK/src/low_storage_rk_perform_step.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,13 @@ end
3333
integrator.u = u
3434
end
3535

36-
get_fsalfirstlast(cache::LowStorageRK2NCache, u) = (cache.k, cache.k)
36+
get_fsalfirstlast(cache::LowStorageRK2NCache, u) = (nothing, nothing)
3737

3838
function initialize!(integrator, cache::LowStorageRK2NCache)
3939
@unpack k, tmp, williamson_condition = cache
4040
integrator.kshortsize = 1
4141
resize!(integrator.k, integrator.kshortsize)
4242
integrator.k[1] = k
43-
integrator.fsalfirst = k # used for get_du
44-
integrator.fsallast = k
4543
integrator.f(k, integrator.uprev, integrator.p, integrator.t) # FSAL for interpolation
4644
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
4745
end

lib/OrdinaryDiffEqPDIRK/src/pdirk_caches.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
end
99

1010
# Non-FSAL
11-
get_fsalfirstlast(cache::PDIRK44Cache, u) = (cache.u, cache.uprev)
11+
get_fsalfirstlast(cache::PDIRK44Cache, u) = (nothing, nothing)
1212

1313
struct PDIRK44ConstantCache{N, TabType} <: OrdinaryDiffEqConstantCache
1414
nlsolver::N

lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ function gen_cache_struct(tab::RosenbrockTableau,cachename::Symbol,constcachenam
176176
end
177177
end
178178
cacheexpr=quote
179-
@cache mutable struct $cachename{uType,rateType,uNoUnitsType,JType,WType,TabType,TFType,UFType,F,JCType,GCType} <: RosenbrockMutableCache
179+
@cache mutable struct $cachename{uType,rateType,uNoUnitsType,JType,WType,TabType,TFType,UFType,F,JCType,GCType} <: GenericRosenbrockMutableCache
180180
u::uType
181181
uprev::uType
182182
du::rateType

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
abstract type RosenbrockMutableCache <: OrdinaryDiffEqMutableCache end
2+
abstract type GenericRosenbrockMutableCache <: RosenbrockMutableCache end
23
abstract type RosenbrockConstantCache <: OrdinaryDiffEqConstantCache end
34

45
# Fake values since non-FSAL
5-
get_fsalfirstlast(cache::RosenbrockMutableCache, u) = (zero(u), zero(u))
6+
get_fsalfirstlast(cache::RosenbrockMutableCache, u) = (nothing, nothing)
7+
get_fsalfirstlast(cache::GenericRosenbrockMutableCache, u) = (cache.fsalfirst, cache.fsallast)
68

79
################################################################################
810

0 commit comments

Comments
 (0)