Skip to content

Commit afd385b

Browse files
changes
1 parent 60ab391 commit afd385b

File tree

3 files changed

+45
-12
lines changed

3 files changed

+45
-12
lines changed

lib/OrdinaryDiffEqFIRK/src/controllers.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function step_accept_controller!(integrator, controller::PredictiveController, alg::AdaptiveRadau, q)
22
@unpack qmin, qmax, gamma, qsteady_min, qsteady_max = integrator.opts
33
@unpack cache = integrator
4-
@unpack num_stages, step, iter, hist_iter = cache
4+
@unpack num_stages, step, iter, hist_iter, index = cache
55

66
EEst = DiffEqBase.value(integrator.EEst)
77

@@ -25,12 +25,14 @@ function step_accept_controller!(integrator, controller::PredictiveController, a
2525
max_stages = (alg.max_order - 1) ÷ 4 * 2 + 1
2626
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
2727
if (step > 10)
28-
if (hist_iter < 2.6 && num_stages <= max_stages)
28+
if (hist_iter < 2.6 && num_stages < max_stages)
2929
cache.num_stages += 2
30+
cache.index += 1
3031
cache.step = 1
3132
cache.hist_iter = iter
32-
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
33+
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > min_stages)
3334
cache.num_stages -= 2
35+
cache.index -= 1
3436
cache.step = 1
3537
cache.hist_iter = iter
3638
end
@@ -48,8 +50,9 @@ function step_reject_controller!(integrator, controller::PredictiveController, a
4850
cache.hist_iter = hist_iter
4951
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
5052
if (step > 10)
51-
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
53+
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > min_stages)
5254
cache.num_stages -= 2
55+
cache.index -= 1
5356
cache.step = 1
5457
cache.hist_iter = iter
5558
end

lib/OrdinaryDiffEqFIRK/src/firk_caches.jl

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ mutable struct AdaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType} <:
497497
num_stages::Int
498498
step::Int
499499
hist_iter::Float64
500+
index::Int
500501
end
501502

502503
function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -518,7 +519,11 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
518519
num_stages = min
519520

520521
tabs = [RadauIIATableau5(uToltype, constvalue(tTypeNoUnits)), RadauIIATableau9(uToltype, constvalue(tTypeNoUnits)), RadauIIATableau13(uToltype, constvalue(tTypeNoUnits))]
521-
i = 9
522+
if (min == 3 || min == 5 || min == 7)
523+
i = 9
524+
else
525+
i = min
526+
end
522527
while i <= max
523528
push!(tabs, RadauIIATableau(uToltype, constvalue(tTypeNoUnits), i))
524529
i += 2
@@ -528,10 +533,20 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
528533
cont[i] = zero(u)
529534
end
530535

536+
if (min == 3)
537+
index = 1
538+
elseif (min == 5)
539+
index = 2
540+
elseif (min == 7)
541+
index = 3
542+
else
543+
index = 4
544+
end
545+
531546
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
532547
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'
533548
AdaptiveRadauConstantCache(uf, tabs, κ, one(uToltype), 10000, cont, dt, dt,
534-
Convergence, J, num_stages, 1, 0.0)
549+
Convergence, J, num_stages, 1, 0.0, index)
535550
end
536551

537552
mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType, JType, W1Type, W2Type,
@@ -578,6 +593,7 @@ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType,
578593
num_stages::Int
579594
step::Int
580595
hist_iter::Float64
596+
index::Int
581597
end
582598

583599
function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -599,12 +615,26 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
599615
num_stages = min
600616

601617
tabs = [RadauIIATableau5(uToltype, constvalue(tTypeNoUnits)), RadauIIATableau9(uToltype, constvalue(tTypeNoUnits)), RadauIIATableau13(uToltype, constvalue(tTypeNoUnits))]
602-
i = 9
618+
if (min == 3 || min == 5 || min == 7)
619+
i = 9
620+
else
621+
i = min
622+
end
603623
while i <= max
604624
push!(tabs, RadauIIATableau(uToltype, constvalue(tTypeNoUnits), i))
605625
i += 2
606626
end
607627

628+
if (min == 3)
629+
index = 1
630+
elseif (min == 5)
631+
index = 2
632+
elseif (min == 7)
633+
index = 3
634+
else
635+
index = 4
636+
end
637+
608638
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
609639

610640
z = Vector{typeof(u)}(undef, max)
@@ -677,6 +707,6 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
677707
uf, tabs, κ, one(uToltype), 10000, tmp,
678708
atmp, jac_config,
679709
linsolve1, linsolve2, rtol, atol, dt, dt,
680-
Convergence, alg.step_limiter!, num_stages, 1, 0.0)
710+
Convergence, alg.step_limiter!, num_stages, 1, 0.0, index)
681711
end
682712

lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,8 +1354,8 @@ end
13541354
@muladd function perform_step!(integrator, cache::AdaptiveRadauConstantCache,
13551355
repeat_step = false)
13561356
@unpack t, dt, uprev, u, f, p = integrator
1357-
@unpack tabs, num_stages = cache
1358-
tab = tabs[(num_stages - 1) ÷ 2]
1357+
@unpack tabs, num_stages, index = cache
1358+
tab = tabs[index]
13591359
@unpack T, TI, γ, α, β, c, e = tab
13601360
@unpack κ, cont = cache
13611361
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
@@ -1595,8 +1595,8 @@ end
15951595

15961596
@muladd function perform_step!(integrator, cache::AdaptiveRadauCache, repeat_step = false)
15971597
@unpack t, dt, uprev, u, f, p, fsallast, fsalfirst = integrator
1598-
@unpack num_stages, tabs = cache
1599-
tab = tabs[(num_stages - 1) ÷ 2]
1598+
@unpack num_stages, tabs, index = cache
1599+
tab = tabs[index]
16001600
@unpack T, TI, γ, α, β, c, e = tab
16011601
@unpack κ, cont, derivatives, z, w, c_prime, αdt, βdt= cache
16021602
@unpack dw1, ubuff, dw2, cubuff, dw = cache

0 commit comments

Comments
 (0)