Skip to content

Commit f65bcca

Browse files
Merge pull request #2478 from Shreyas-Ekanathan/master
Add Order Adaptivity to Radau
2 parents 45e7a23 + 3f07e1d commit f65bcca

File tree

8 files changed

+250
-165
lines changed

8 files changed

+250
-165
lines changed
Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
qmax_default(alg::Union{RadauIIA3, RadauIIA5, RadauIIA9}) = 8
1+
qmax_default(alg::Union{RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau}) = 8
22

33
alg_order(alg::RadauIIA3) = 3
44
alg_order(alg::RadauIIA5) = 5
55
alg_order(alg::RadauIIA9) = 9
6-
alg_order(alg::AdaptiveRadau) = 5
6+
alg_order(alg::AdaptiveRadau) = 5 #dummy value
77

88
isfirk(alg::RadauIIA3) = true
99
isfirk(alg::RadauIIA5) = true
@@ -13,3 +13,6 @@ isfirk(alg::AdaptiveRadau) = true
1313
alg_adaptive_order(alg::RadauIIA3) = 1
1414
alg_adaptive_order(alg::RadauIIA5) = 3
1515
alg_adaptive_order(alg::RadauIIA9) = 5
16+
17+
get_current_alg_order(alg::AdaptiveRadau, cache) = cache.num_stages * 2 - 1
18+
get_current_adaptive_order(alg::AdaptiveRadau, cache) = cache.num_stages

lib/OrdinaryDiffEqFIRK/src/algorithms.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,13 @@ struct AdaptiveRadau{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <:
163163
new_W_γdt_cutoff::C2
164164
controller::Symbol
165165
step_limiter!::StepLimiter
166-
num_stages::Int
166+
min_stages::Int
167+
max_stages::Int
167168
end
168169

169170
function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
170171
standardtag = Val{true}(), concrete_jac = nothing,
171-
diff_type = Val{:forward}, num_stages = 3,
172+
diff_type = Val{:forward}, min_stages = 3, max_stages = 7,
172173
linsolve = nothing, precs = DEFAULT_PRECS,
173174
extrapolant = :dense, fast_convergence_cutoff = 1 // 5,
174175
new_W_γdt_cutoff = 1 // 5,
@@ -186,6 +187,6 @@ function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
186187
fast_convergence_cutoff,
187188
new_W_γdt_cutoff,
188189
controller,
189-
step_limiter!, num_stages)
190+
step_limiter!, min_stages, max_stages)
190191
end
191192

lib/OrdinaryDiffEqFIRK/src/controllers.jl

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
@inline function stepsize_controller!(integrator, controller::PredictiveController, alg)
22
@unpack qmin, qmax, gamma = integrator.opts
33
EEst = DiffEqBase.value(integrator.EEst)
4-
54
if iszero(EEst)
65
q = inv(qmax)
76
else
@@ -26,6 +25,7 @@ end
2625

2726
function step_accept_controller!(integrator, controller::PredictiveController, alg, q)
2827
@unpack qmin, qmax, gamma, qsteady_min, qsteady_max = integrator.opts
28+
2929
EEst = DiffEqBase.value(integrator.EEst)
3030

3131
if integrator.success_iter > 0
@@ -42,10 +42,68 @@ function step_accept_controller!(integrator, controller::PredictiveController, a
4242
end
4343
integrator.dtacc = integrator.dt
4444
integrator.erracc = max(1e-2, EEst)
45+
46+
return integrator.dt / qacc
47+
end
48+
49+
50+
function step_accept_controller!(integrator, controller::PredictiveController, alg::AdaptiveRadau, q)
51+
@unpack qmin, qmax, gamma, qsteady_min, qsteady_max = integrator.opts
52+
@unpack cache = integrator
53+
@unpack num_stages, step, iter, hist_iter = cache
54+
55+
EEst = DiffEqBase.value(integrator.EEst)
56+
57+
if integrator.success_iter > 0
58+
expo = 1 / (get_current_adaptive_order(alg, integrator.cache) + 1)
59+
qgus = (integrator.dtacc / integrator.dt) *
60+
DiffEqBase.fastpow((EEst^2) / integrator.erracc, expo)
61+
qgus = max(inv(qmax), min(inv(qmin), qgus / gamma))
62+
qacc = max(q, qgus)
63+
else
64+
qacc = q
65+
end
66+
if qsteady_min <= qacc <= qsteady_max
67+
qacc = one(qacc)
68+
end
69+
integrator.dtacc = integrator.dt
70+
integrator.erracc = max(1e-2, EEst)
71+
cache.step = step + 1
72+
hist_iter = hist_iter * 0.8 + iter * 0.2
73+
cache.hist_iter = hist_iter
74+
if (step > 10)
75+
if (hist_iter < 2.6 && num_stages < alg.max_stages)
76+
cache.num_stages += 2
77+
cache.step = 1
78+
cache.hist_iter = iter
79+
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > alg.min_stages)
80+
cache.num_stages -= 2
81+
cache.step = 1
82+
cache.hist_iter = iter
83+
end
84+
end
4585
return integrator.dt / qacc
4686
end
4787

4888
function step_reject_controller!(integrator, controller::PredictiveController, alg)
4989
@unpack dt, success_iter, qold = integrator
5090
integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold
5191
end
92+
93+
function step_reject_controller!(integrator, controller::PredictiveController, alg::AdaptiveRadau)
94+
@unpack dt, success_iter, qold = integrator
95+
@unpack cache = integrator
96+
@unpack num_stages, step, iter, hist_iter = cache
97+
integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold
98+
cache.step = step + 1
99+
hist_iter = hist_iter * 0.8 + iter * 0.2
100+
cache.hist_iter = hist_iter
101+
if (step > 10)
102+
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > alg.min_stages)
103+
cache.num_stages -= 2
104+
cache.step = 1
105+
cache.hist_iter = iter
106+
end
107+
end
108+
end
109+

lib/OrdinaryDiffEqFIRK/src/firk_caches.jl

Lines changed: 49 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ end
477477
mutable struct AdaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType} <:
478478
OrdinaryDiffEqConstantCache
479479
uf::F
480-
tab::Tab
480+
tabs::Vector{Tab}
481481
κ::Tol
482482
ηold::Tol
483483
iter::Int
@@ -486,6 +486,9 @@ mutable struct AdaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType} <:
486486
W_γdt::Dt
487487
status::NLStatus
488488
J::JType
489+
num_stages::Int
490+
step::Int
491+
hist_iter::Float64
489492
end
490493

491494
function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -494,30 +497,24 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
494497
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
495498
uf = UDerivativeWrapper(f, t, p)
496499
uToltype = constvalue(uBottomEltypeNoUnits)
497-
num_stages = alg.num_stages
498-
499-
if (num_stages == 3)
500-
tab = BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits))
501-
elseif (num_stages == 5)
502-
tab = BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits))
503-
elseif (num_stages == 7)
504-
tab = BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))
505-
elseif iseven(num_stages) || num_stages <3
506-
error("num_stages must be odd and 3 or greater")
507-
else
508-
tab = adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), num_stages)
500+
num_stages = alg.min_stages
501+
max = alg.max_stages
502+
tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
503+
504+
i = 9
505+
while i <= alg.max_stages
506+
push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i))
507+
i += 2
509508
end
510-
511-
cont = Vector{typeof(u)}(undef, num_stages)
512-
for i in 1: num_stages
509+
cont = Vector{typeof(u)}(undef, max)
510+
for i in 1: max
513511
cont[i] = zero(u)
514512
end
515513

516514
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
517515
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'
518-
519-
AdaptiveRadauConstantCache(uf, tab, κ, one(uToltype), 10000, cont, dt, dt,
520-
Convergence, J)
516+
AdaptiveRadauConstantCache(uf, tabs, κ, one(uToltype), 10000, cont, dt, dt,
517+
Convergence, J, num_stages, 1, 0.0)
521518
end
522519

523520
mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType, JType, W1Type, W2Type,
@@ -544,7 +541,7 @@ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType,
544541
W1::W1Type #real
545542
W2::Vector{W2Type} #complex
546543
uf::UF
547-
tab::Tab
544+
tabs::Vector{Tab}
548545
κ::Tol
549546
ηold::Tol
550547
iter::Int
@@ -559,6 +556,9 @@ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType,
559556
W_γdt::Dt
560557
status::NLStatus
561558
step_limiter!::StepLimiter
559+
num_stages::Int
560+
step::Int
561+
hist_iter::Float64
562562
end
563563

564564
function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -567,62 +567,57 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
567567
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
568568
uf = UJacobianWrapper(f, t, p)
569569
uToltype = constvalue(uBottomEltypeNoUnits)
570-
num_stages = alg.num_stages
571-
572-
if (num_stages == 3)
573-
tab = BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits))
574-
elseif (num_stages == 5)
575-
tab = BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits))
576-
elseif (num_stages == 7)
577-
tab = BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))
578-
elseif iseven(num_stages) || num_stages < 3
579-
error("num_stages must be odd and 3 or greater")
580-
else
581-
tab = adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), num_stages)
570+
571+
max = alg.max_stages
572+
num_stages = alg.min_stages
573+
574+
tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
575+
i = 9
576+
while i <= max
577+
push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i))
578+
i += 2
582579
end
583580

584581
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
585582

586-
z = Vector{typeof(u)}(undef, num_stages)
587-
w = Vector{typeof(u)}(undef, num_stages)
588-
for i in 1 : num_stages
583+
z = Vector{typeof(u)}(undef, max)
584+
w = Vector{typeof(u)}(undef, max)
585+
for i in 1 : max
589586
z[i] = w[i] = zero(u)
590587
end
591588

592-
c_prime = Vector{typeof(t)}(undef, num_stages) #time stepping
589+
c_prime = Vector{typeof(t)}(undef, max) #time stepping
590+
for i in 1 : max
591+
c_prime[i] = zero(t)
592+
end
593593

594594
dw1 = zero(u)
595595
ubuff = zero(u)
596-
dw2 = [similar(u, Complex{eltype(u)}) for _ in 1 : (num_stages - 1) ÷ 2]
596+
dw2 = [similar(u, Complex{eltype(u)}) for _ in 1 : (max - 1) ÷ 2]
597597
recursivefill!.(dw2, false)
598-
cubuff = [similar(u, Complex{eltype(u)}) for _ in 1 : (num_stages - 1) ÷ 2]
598+
cubuff = [similar(u, Complex{eltype(u)}) for _ in 1 : (max - 1) ÷ 2]
599599
recursivefill!.(cubuff, false)
600-
dw = Vector{typeof(u)}(undef, num_stages - 1)
600+
dw = [zero(u) for i in 1 : max]
601601

602-
cont = Vector{typeof(u)}(undef, num_stages)
603-
for i in 1 : num_stages
604-
cont[i] = zero(u)
605-
end
602+
cont = [zero(u) for i in 1:max]
606603

607-
derivatives = Matrix{typeof(u)}(undef, num_stages, num_stages)
608-
for i in 1 : num_stages, j in 1 : num_stages
604+
derivatives = Matrix{typeof(u)}(undef, max, max)
605+
for i in 1 : max, j in 1 : max
609606
derivatives[i, j] = zero(u)
610607
end
611608

612609
fsalfirst = zero(rate_prototype)
613-
fw = Vector{typeof(rate_prototype)}(undef, num_stages)
614-
ks = Vector{typeof(rate_prototype)}(undef, num_stages)
615-
for i in 1: num_stages
616-
ks[i] = fw[i] = zero(rate_prototype)
617-
end
610+
fw = [zero(rate_prototype) for i in 1 : max]
611+
ks = [zero(rate_prototype) for i in 1 : max]
612+
618613
k = ks[1]
619614

620615
J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
621616
if J isa AbstractSciMLOperator
622617
error("Non-concrete Jacobian not yet supported by AdaptiveRadau.")
623618
end
624619

625-
W2 = [similar(J, Complex{eltype(W1)}) for _ in 1 : (num_stages - 1) ÷ 2]
620+
W2 = [similar(J, Complex{eltype(W1)}) for _ in 1 : (max - 1) ÷ 2]
626621
recursivefill!.(W2, false)
627622

628623
du1 = zero(rate_prototype)
@@ -640,7 +635,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
640635

641636
linsolve2 = [
642637
init(LinearProblem(W2[i], _vec(cubuff[i]); u0 = _vec(dw2[i])), alg.linsolve, alias_A = true, alias_b = true,
643-
assumptions = LinearSolve.OperatorAssumptions(true)) for i in 1 : (num_stages - 1) ÷ 2]
638+
assumptions = LinearSolve.OperatorAssumptions(true)) for i in 1 : (max - 1) ÷ 2]
644639

645640
rtol = reltol isa Number ? reltol : zero(reltol)
646641
atol = reltol isa Number ? reltol : zero(reltol)
@@ -649,9 +644,9 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
649644
z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
650645
du1, fsalfirst, ks, k, fw,
651646
J, W1, W2,
652-
uf, tab, κ, one(uToltype), 10000, tmp,
647+
uf, tabs, κ, one(uToltype), 10000, tmp,
653648
atmp, jac_config,
654649
linsolve1, linsolve2, rtol, atol, dt, dt,
655-
Convergence, alg.step_limiter!)
650+
Convergence, alg.step_limiter!, num_stages, 1, 0.0)
656651
end
657652

0 commit comments

Comments
 (0)