Skip to content

Commit 5c00abb

Browse files
Merge pull request #2548 from Shreyas-Ekanathan/master
Small Radau Optimizations
2 parents a6b4262 + c069898 commit 5c00abb

File tree

4 files changed

+51
-42
lines changed

4 files changed

+51
-42
lines changed

lib/OrdinaryDiffEqFIRK/src/controllers.jl

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
function step_accept_controller!(integrator, controller::PredictiveController, alg::AdaptiveRadau, q)
22
@unpack qmin, qmax, gamma, qsteady_min, qsteady_max = integrator.opts
33
@unpack cache = integrator
4-
@unpack num_stages, step, iter, hist_iter = cache
4+
@unpack num_stages, step, iter, hist_iter, index = cache
55

66
EEst = DiffEqBase.value(integrator.EEst)
77

@@ -25,12 +25,14 @@ function step_accept_controller!(integrator, controller::PredictiveController, a
2525
max_stages = (alg.max_order - 1) ÷ 4 * 2 + 1
2626
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
2727
if (step > 10)
28-
if (hist_iter < 2.6 && num_stages <= max_stages)
28+
if (hist_iter < 2.6 && num_stages < max_stages)
2929
cache.num_stages += 2
30+
cache.index += 1
3031
cache.step = 1
3132
cache.hist_iter = iter
32-
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
33+
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > min_stages)
3334
cache.num_stages -= 2
35+
cache.index -= 1
3436
cache.step = 1
3537
cache.hist_iter = iter
3638
end
@@ -48,8 +50,9 @@ function step_reject_controller!(integrator, controller::PredictiveController, a
4850
cache.hist_iter = hist_iter
4951
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
5052
if (step > 10)
51-
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
53+
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > min_stages)
5254
cache.num_stages -= 2
55+
cache.index -= 1
5356
cache.step = 1
5457
cache.hist_iter = iter
5558
end

lib/OrdinaryDiffEqFIRK/src/firk_caches.jl

Lines changed: 39 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -497,6 +497,7 @@ mutable struct AdaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType} <:
497497
num_stages::Int
498498
step::Int
499499
hist_iter::Float64
500+
index::Int
500501
end
501502

502503
function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -508,30 +509,32 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
508509

509510
max_order = alg.max_order
510511
min_order = alg.min_order
511-
max = (max_order - 1) ÷ 4 * 2 + 1
512-
min = (min_order - 1) ÷ 4 * 2 + 1
512+
max_stages = (max_order - 1) ÷ 4 * 2 + 1
513+
min_stages = (min_order - 1) ÷ 4 * 2 + 1
513514
if (alg.min_order < 5)
514515
error("min_order choice $min_order below 5 is not compatible with the algorithm")
515-
elseif (max < min)
516+
elseif (max_stages < min_stages)
516517
error("max_order $max_order is below min_order $min_order")
517518
end
518-
num_stages = min
519+
num_stages = min_stages
519520

520521
tabs = [RadauIIATableau5(uToltype, constvalue(tTypeNoUnits)), RadauIIATableau9(uToltype, constvalue(tTypeNoUnits)), RadauIIATableau13(uToltype, constvalue(tTypeNoUnits))]
521-
i = 9
522-
while i <= max
522+
i = max(min_stages, 9)
523+
while i <= max_stages
523524
push!(tabs, RadauIIATableau(uToltype, constvalue(tTypeNoUnits), i))
524525
i += 2
525526
end
526-
cont = Vector{typeof(u)}(undef, max)
527-
for i in 1:max
527+
cont = Vector{typeof(u)}(undef, max_stages)
528+
for i in 1:max_stages
528529
cont[i] = zero(u)
529530
end
530531

532+
index = min((min_stages - 1) ÷ 2, 4)
533+
531534
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
532535
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'
533536
AdaptiveRadauConstantCache(uf, tabs, κ, one(uToltype), 10000, cont, dt, dt,
534-
Convergence, J, num_stages, 1, 0.0)
537+
Convergence, J, num_stages, 1, 0.0, index)
535538
end
536539

537540
mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType, JType, W1Type, W2Type,
@@ -578,6 +581,7 @@ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType,
578581
num_stages::Int
579582
step::Int
580583
hist_iter::Float64
584+
index::Int
581585
end
582586

583587
function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -589,56 +593,58 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
589593

590594
max_order = alg.max_order
591595
min_order = alg.min_order
592-
max = (max_order - 1) ÷ 4 * 2 + 1
593-
min = (min_order - 1) ÷ 4 * 2 + 1
596+
max_stages = (max_order - 1) ÷ 4 * 2 + 1
597+
min_stages = (min_order - 1) ÷ 4 * 2 + 1
594598
if (alg.min_order < 5)
595599
error("min_order choice $min_order below 5 is not compatible with the algorithm")
596-
elseif (max < min)
600+
elseif (max_stages < min_stages)
597601
error("max_order $max_order is below min_order $min_order")
598602
end
599-
num_stages = min
603+
num_stages = min_stages
600604

601605
tabs = [RadauIIATableau5(uToltype, constvalue(tTypeNoUnits)), RadauIIATableau9(uToltype, constvalue(tTypeNoUnits)), RadauIIATableau13(uToltype, constvalue(tTypeNoUnits))]
602-
i = 9
603-
while i <= max
606+
i = max(min_stages, 9)
607+
while i <= max_stages
604608
push!(tabs, RadauIIATableau(uToltype, constvalue(tTypeNoUnits), i))
605609
i += 2
606610
end
607611

612+
index = min((min_stages - 1) ÷ 2, 4)
613+
608614
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
609615

610-
z = Vector{typeof(u)}(undef, max)
611-
w = Vector{typeof(u)}(undef, max)
612-
for i in 1 : max
616+
z = Vector{typeof(u)}(undef, max_stages)
617+
w = Vector{typeof(u)}(undef, max_stages)
618+
for i in 1 : max_stages
613619
z[i] = zero(u)
614620
w[i] = zero(u)
615621
end
616622

617-
αdt = [zero(t) for i in 1:max]
618-
βdt = [zero(t) for i in 1:max]
619-
c_prime = Vector{typeof(t)}(undef, max) #time stepping
620-
for i in 1 : max
623+
αdt = [zero(t) for i in 1:max_stages]
624+
βdt = [zero(t) for i in 1:max_stages]
625+
c_prime = Vector{typeof(t)}(undef, max_stages) #time stepping
626+
for i in 1 : max_stages
621627
c_prime[i] = zero(t)
622628
end
623629

624630
dw1 = zero(u)
625631
ubuff = zero(u)
626-
dw2 = [similar(u, Complex{eltype(u)}) for _ in 1 : (max - 1) ÷ 2]
632+
dw2 = [similar(u, Complex{eltype(u)}) for _ in 1 : (max_stages - 1) ÷ 2]
627633
recursivefill!.(dw2, false)
628-
cubuff = [similar(u, Complex{eltype(u)}) for _ in 1 : (max - 1) ÷ 2]
634+
cubuff = [similar(u, Complex{eltype(u)}) for _ in 1 : (max_stages - 1) ÷ 2]
629635
recursivefill!.(cubuff, false)
630-
dw = [zero(u) for i in 1 : max]
636+
dw = [zero(u) for i in 1:max_stages]
631637

632-
cont = [zero(u) for i in 1:max]
638+
cont = [zero(u) for i in 1:max_stages]
633639

634-
derivatives = Matrix{typeof(u)}(undef, max, max)
635-
for i in 1 : max, j in 1 : max
640+
derivatives = Matrix{typeof(u)}(undef, max_stages, max_stages)
641+
for i in 1 : max_stages, j in 1 : max_stages
636642
derivatives[i, j] = zero(u)
637643
end
638644

639645
fsalfirst = zero(rate_prototype)
640-
fw = [zero(rate_prototype) for i in 1 : max]
641-
ks = [zero(rate_prototype) for i in 1 : max]
646+
fw = [zero(rate_prototype) for i in 1 : max_stages]
647+
ks = [zero(rate_prototype) for i in 1 : max_stages]
642648

643649
k = ks[1]
644650

@@ -647,7 +653,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
647653
error("Non-concrete Jacobian not yet supported by AdaptiveRadau.")
648654
end
649655

650-
W2 = [similar(J, Complex{eltype(W1)}) for _ in 1 : (max - 1) ÷ 2]
656+
W2 = [similar(J, Complex{eltype(W1)}) for _ in 1 : (max_stages - 1) ÷ 2]
651657
recursivefill!.(W2, false)
652658

653659
du1 = zero(rate_prototype)
@@ -665,7 +671,7 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
665671

666672
linsolve2 = [
667673
init(LinearProblem(W2[i], _vec(cubuff[i]); u0 = _vec(dw2[i])), alg.linsolve, alias_A = true, alias_b = true,
668-
assumptions = LinearSolve.OperatorAssumptions(true)) for i in 1 : (max - 1) ÷ 2]
674+
assumptions = LinearSolve.OperatorAssumptions(true)) for i in 1 : (max_stages - 1) ÷ 2]
669675

670676
rtol = reltol isa Number ? reltol : zero(reltol)
671677
atol = reltol isa Number ? reltol : zero(reltol)
@@ -677,6 +683,6 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
677683
uf, tabs, κ, one(uToltype), 10000, tmp,
678684
atmp, jac_config,
679685
linsolve1, linsolve2, rtol, atol, dt, dt,
680-
Convergence, alg.step_limiter!, num_stages, 1, 0.0)
686+
Convergence, alg.step_limiter!, num_stages, 1, 0.0, index)
681687
end
682688

lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1354,8 +1354,8 @@ end
13541354
@muladd function perform_step!(integrator, cache::AdaptiveRadauConstantCache,
13551355
repeat_step = false)
13561356
@unpack t, dt, uprev, u, f, p = integrator
1357-
@unpack tabs, num_stages = cache
1358-
tab = tabs[(num_stages - 1) ÷ 2]
1357+
@unpack tabs, num_stages, index = cache
1358+
tab = tabs[index]
13591359
@unpack T, TI, γ, α, β, c, e = tab
13601360
@unpack κ, cont = cache
13611361
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
@@ -1595,8 +1595,8 @@ end
15951595

15961596
@muladd function perform_step!(integrator, cache::AdaptiveRadauCache, repeat_step = false)
15971597
@unpack t, dt, uprev, u, f, p, fsallast, fsalfirst = integrator
1598-
@unpack num_stages, tabs = cache
1599-
tab = tabs[(num_stages - 1) ÷ 2]
1598+
@unpack num_stages, tabs, index = cache
1599+
tab = tabs[index]
16001600
@unpack T, TI, γ, α, β, c, e = tab
16011601
@unpack κ, cont, derivatives, z, w, c_prime, αdt, βdt= cache
16021602
@unpack dw1, ubuff, dw2, cubuff, dw = cache

lib/OrdinaryDiffEqFIRK/test/ode_high_order_firk_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ testTol = 0.5
77
prob_ode_linear_big = remake(prob_ode_linear, u0 = big.(prob_ode_linear.u0), tspan = big.(prob_ode_linear.tspan))
88
prob_ode_2Dlinear_big = remake(prob_ode_2Dlinear, u0 = big.(prob_ode_2Dlinear.u0), tspan = big.(prob_ode_2Dlinear.tspan))
99

10-
for i in [17, 21], prob in [prob_ode_linear_big, prob_ode_2Dlinear_big]
10+
for i in [17, 21, 25], prob in [prob_ode_linear_big, prob_ode_2Dlinear_big]
1111
dts = 1 ./ 2 .^ (4.25:-1:0.25)
1212
sim21 = test_convergence(dts, prob, AdaptiveRadau(min_order = i, max_order = i))
1313
@test sim21.𝒪est[:final] ≈ i atol=testTol

0 commit comments

Comments
 (0)