Skip to content

Commit 937641f

Browse files
speed ups
1 parent 9ec5720 commit 937641f

File tree

4 files changed

+75
-52
lines changed

4 files changed

+75
-52
lines changed

lib/OrdinaryDiffEqFIRK/src/algorithms.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ struct AdaptiveRadau{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <:
163163
new_W_γdt_cutoff::C2
164164
controller::Symbol
165165
step_limiter!::StepLimiter
166-
min_stages::Int
167-
max_stages::Int
166+
min_order::Int
167+
max_order::Int
168168
end
169169

170170
function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),

lib/OrdinaryDiffEqFIRK/src/controllers.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ function step_accept_controller!(integrator, controller::PredictiveController, a
2222
cache.step = step + 1
2323
hist_iter = hist_iter * 0.8 + iter * 0.2
2424
cache.hist_iter = hist_iter
25+
max_stages = (alg.max_order - 1) ÷ 4 * 2 + 1
26+
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
2527
if (step > 10)
26-
if (hist_iter < 2.6 && num_stages < (alg.max_order + 1) ÷ 2)
28+
if (hist_iter < 2.6 && num_stages <= max_stages)
2729
cache.num_stages += 2
2830
cache.step = 1
2931
cache.hist_iter = iter
30-
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > (alg.min_order + 1) ÷ 2)
32+
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
3133
cache.num_stages -= 2
3234
cache.step = 1
3335
cache.hist_iter = iter
@@ -44,8 +46,9 @@ function step_reject_controller!(integrator, controller::PredictiveController, a
4446
cache.step = step + 1
4547
hist_iter = hist_iter * 0.8 + iter * 0.2
4648
cache.hist_iter = hist_iter
49+
min_stages = (alg.min_order - 1) ÷ 4 * 2 + 1
4750
if (step > 10)
48-
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > (alg.min_order + 1) ÷ 2)
51+
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages >= min_stages)
4952
cache.num_stages -= 2
5053
cache.step = 1
5154
cache.hist_iter = iter

lib/OrdinaryDiffEqFIRK/src/firk_caches.jl

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,10 @@ mutable struct RadauIIA9Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Ty
362362
tmp4::uType
363363
tmp5::uType
364364
tmp6::uType
365+
tmp7::uType
366+
tmp8::uType
367+
tmp9::uType
368+
tmp10::uType
365369
atmp::uNoUnitsType
366370
jac_config::JC
367371
linsolve1::F1
@@ -440,6 +444,10 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
440444
tmp4 = zero(u)
441445
tmp5 = zero(u)
442446
tmp6 = zero(u)
447+
tmp7 = zero(u)
448+
tmp8 = zero(u)
449+
tmp9 = zero(u)
450+
tmp10 = zero(u)
443451
atmp = similar(u, uEltypeNoUnits)
444452
recursivefill!(atmp, false)
445453
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)
@@ -469,7 +477,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
469477
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
470478
J, W1, W2, W3,
471479
uf, tab, κ, one(uToltype), 10000,
472-
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config,
480+
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, atmp, jac_config,
473481
linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt,
474482
Convergence, alg.step_limiter!)
475483
end
@@ -497,17 +505,26 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
497505
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
498506
uf = UDerivativeWrapper(f, t, p)
499507
uToltype = constvalue(uBottomEltypeNoUnits)
500-
max = (alg.max_order + 1) ÷ 2
501-
num_stages = (alg.min_order + 1) ÷ 2
508+
509+
max_order = alg.max_order
510+
min_order = alg.min_order
511+
max = (max_order - 1) ÷ 4 * 2 + 1
512+
min = (min_order - 1) ÷ 4 * 2 + 1
513+
if (alg.min_order < 5)
514+
error("min_order choice $min_order below 5 is not compatible with the algorithm")
515+
elseif (max < min)
516+
error("max_order $max_order is below min_order $min_order")
517+
end
518+
num_stages = min
519+
502520
tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
503-
504521
i = 9
505522
while i <= max
506523
push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i))
507524
i += 2
508525
end
509526
cont = Vector{typeof(u)}(undef, max)
510-
for i in 1: max
527+
for i in 1:max
511528
cont[i] = zero(u)
512529
end
513530

@@ -570,8 +587,16 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
570587
uf = UJacobianWrapper(f, t, p)
571588
uToltype = constvalue(uBottomEltypeNoUnits)
572589

573-
max = (alg.max_order + 1) ÷ 2
574-
num_stages = (alg.min_order + 1) ÷ 2
590+
max_order = alg.max_order
591+
min_order = alg.min_order
592+
max = (max_order - 1) ÷ 4 * 2 + 1
593+
min = (min_order - 1) ÷ 4 * 2 + 1
594+
if (alg.min_order < 5)
595+
error("min_order choice $min_order below 5 is not compatible with the algorithm")
596+
elseif (max < min)
597+
error("max_order $max_order is below min_order $min_order")
598+
end
599+
num_stages = min
575600

576601
tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
577602
i = 9

lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,7 +1032,7 @@ end
10321032
@unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache
10331033
@unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache
10341034
@unpack J, W1, W2, W3 = cache
1035-
@unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, step_limiter! = cache
1035+
@unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, step_limiter! = cache
10361036
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
10371037
alg = unwrap_alg(integrator, true)
10381038
@unpack maxiters = alg
@@ -1087,30 +1087,30 @@ end
10871087
c2′ = c2 * c5′
10881088
c3′ = c3 * c5′
10891089
c4′ = c4 * c5′
1090-
z1 = @.. c1′ * (cont1 +
1090+
@.. z1 = c1′ * (cont1 +
10911091
(c1′-c4m1) * (cont2 +
10921092
(c1′ - c3m1) * (cont3 +
10931093
(c1′ - c2m1) * (cont4 + (c1′ - c1m1) * cont5))))
1094-
z2 = @.. c2′ * (cont1 +
1094+
@.. z2 = c2′ * (cont1 +
10951095
(c2′-c4m1) * (cont2 +
10961096
(c2′ - c3m1) * (cont3 +
10971097
(c2′ - c2m1) * (cont4 + (c2′ - c1m1) * cont5))))
1098-
z3 = @.. c3′ * (cont1 +
1098+
@.. z3 = c3′ * (cont1 +
10991099
(c3′-c4m1) * (cont2 +
11001100
(c3′ - c3m1) * (cont3 +
11011101
(c3′ - c2m1) * (cont4 + (c3′ - c1m1) * cont5))))
1102-
z4 = @.. c4′ * (cont1 +
1102+
@.. z4 = c4′ * (cont1 +
11031103
(c4′-c4m1) * (cont2 +
11041104
(c4′ - c3m1) * (cont3 +
11051105
(c4′ - c2m1) * (cont4 + (c4′ - c1m1) * cont5))))
1106-
z5 = @.. c5′ * (cont1 +
1106+
@.. z5 = c5′ * (cont1 +
11071107
(c5′-c4m1) * (cont2 +
11081108
(c5′ - c3m1) * (cont3 + (c5′ - c2m1) * (cont4 + (c5′ - c1m1) * cont5))))
1109-
w1 = @.. broadcast=false TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
1110-
w2 = @.. broadcast=false TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
1111-
w3 = @.. broadcast=false TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
1112-
w4 = @.. broadcast=false TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
1113-
w5 = @.. broadcast=false TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
1109+
@.. w1 = TI11*z1+TI12*z2+TI13*z3+TI14*z4+TI15*z5
1110+
@.. w2 = TI21*z1+TI22*z2+TI23*z3+TI24*z4+TI25*z5
1111+
@.. w3 = TI31*z1+TI32*z2+TI33*z3+TI34*z4+TI35*z5
1112+
@.. w4 = TI41*z1+TI42*z2+TI43*z3+TI44*z4+TI45*z5
1113+
@.. w5 = TI51*z1+TI52*z2+TI53*z3+TI54*z4+TI55*z5
11141114
end
11151115

11161116
# Newton iteration
@@ -1328,21 +1328,21 @@ end
13281328
if integrator.EEst <= oneunit(integrator.EEst)
13291329
cache.dtprev = dt
13301330
if alg.extrapolant != :constant
1331-
cache.cont1 = @.. (z4 - z5) / c4m1 # first derivative on [c4, 1]
1332-
tmp1 = @.. (z3 - z4) / c3mc4 # first derivative on [c3, c4]
1333-
cache.cont2 = @.. (tmp1 - cache.cont1) / c3m1 # second derivative on [c3, 1]
1334-
tmp2 = @.. (z2 - z3) / c2mc3 # first derivative on [c2, c3]
1335-
tmp3 = @.. (tmp2 - tmp1) / c2mc4 # second derivative on [c2, c4]
1336-
cache.cont3 = @.. (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1]
1337-
tmp4 = @.. (z1 - z2) / c1mc2 # first derivative on [c1, c2]
1338-
tmp5 = @.. (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3]
1339-
tmp6 = @.. (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4]
1340-
cache.cont4 = @.. (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1]
1341-
tmp7 = @.. z1 / c1 #first derivative on [0, c1]
1342-
tmp8 = @.. (tmp4 - tmp7) / c2 #second derivative on [0, c2]
1343-
tmp9 = @.. (tmp5 - tmp8) / c3 #third derivative on [0, c3]
1344-
tmp10 = @.. (tmp6 - tmp9) / c4 #fourth derivative on [0,c4]
1345-
cache.cont5 = @.. cache.cont4 - tmp10 #fifth derivative on [0,1]
1331+
@.. cache.cont1 = (z4 - z5) / c4m1 # first derivative on [c4, 1]
1332+
@.. tmp = (z3 - z4) / c3mc4 # first derivative on [c3, c4]
1333+
@.. cache.cont2 = (tmp - cache.cont1) / c3m1 # second derivative on [c3, 1]
1334+
@.. tmp2 = (z2 - z3) / c2mc3 # first derivative on [c2, c3]
1335+
@.. tmp3 = (tmp2 - tmp) / c2mc4 # second derivative on [c2, c4]
1336+
@.. cache.cont3 = (tmp3 - cache.cont2) / c2m1 # third derivative on [c2, 1]
1337+
@.. tmp4 = (z1 - z2) / c1mc2 # first derivative on [c1, c2]
1338+
@.. tmp5 = (tmp4 - tmp2) / c1mc3 # second derivative on [c1, c3]
1339+
@.. tmp6 = (tmp5 - tmp3) / c1mc4 # third derivative on [c1, c4]
1340+
@.. cache.cont4 = (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1]
1341+
@.. tmp7 = z1 / c1 #first derivative on [0, c1]
1342+
@.. tmp8 = (tmp4 - tmp7) / c2 #second derivative on [0, c2]
1343+
@.. tmp9 = (tmp5 - tmp8) / c3 #third derivative on [0, c3]
1344+
@.. tmp10 = (tmp6 - tmp9) / c4 #fourth derivative on [0,c4]
1345+
@.. cache.cont5 = cache.cont4 - tmp10 #fifth derivative on [0,1]
13461346
end
13471347
end
13481348

@@ -1437,7 +1437,7 @@ end
14371437
for i in 1 : num_stages
14381438
z[i] = f(uprev + z[i], p, t + c[i] * dt)
14391439
end
1440-
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 5)
1440+
OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages)
14411441

14421442
#fw = TI * ff
14431443
fw = Vector{typeof(u)}(undef, num_stages)
@@ -1619,7 +1619,7 @@ end
16191619
if (new_W = do_newW(integrator, alg, new_jac, cache.W_γdt))
16201620
@inbounds for II in CartesianIndices(J)
16211621
W1[II] = -γdt * mass_matrix[Tuple(II)...] + J[II]
1622-
for i in 1 :(num_stages - 1) ÷ 2
1622+
for i in 1 : (num_stages - 1) ÷ 2
16231623
W2[i][II] = -(αdt[i] + βdt[i] * im) * mass_matrix[Tuple(II)...] + J[II]
16241624
end
16251625
end
@@ -1673,7 +1673,7 @@ end
16731673
@.. tmp = uprev + z[i]
16741674
f(ks[i], tmp, p, t + c[i] * dt)
16751675
end
1676-
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 5)
1676+
OrdinaryDiffEqCore.increment_nf!(integrator.stats, num_stages)
16771677

16781678
#mul!(fw, TI, ks)
16791679
for i in 1:num_stages
@@ -1700,15 +1700,12 @@ end
17001700
@.. ubuff = fw[1] - γdt * Mw[1]
17011701
needfactor = iter == 1 && new_W
17021702

1703-
linsolve1 = cache.linsolve1
17041703
if needfactor
1705-
linres = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff), linu = _vec(dw1))
1704+
cache.linsolve1 = dolinsolve(integrator, linsolve1; A = W1, b = _vec(ubuff), linu = _vec(dw1)).cache
17061705
else
1707-
linres = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff), linu = _vec(dw1))
1706+
cache.linsolve1 = dolinsolve(integrator, linsolve1; A = nothing, b = _vec(ubuff), linu = _vec(dw1)).cache
17081707
end
17091708

1710-
cache.linsolve1 = linres.cache
1711-
17121709
for i in 1 :(num_stages - 1) ÷ 2
17131710
@.. cubuff[i]=complex(
17141711
fw[2 * i] - αdt[i] * Mw[2 * i] + βdt[i] * Mw[2 * i + 1], fw[2 * i + 1] - βdt[i] * Mw[2 * i] - αdt[i] * Mw[2 * i + 1])
@@ -1801,9 +1798,8 @@ end
18011798
@.. broadcast=false ubuff=integrator.fsalfirst + tmp
18021799

18031800
if alg.smooth_est
1804-
linres = dolinsolve(integrator, linres.cache; b = _vec(ubuff),
1805-
linu = _vec(utilde))
1806-
cache.linsolve1 = linres.cache
1801+
cache.linsolve1 = dolinsolve(integrator, linsolve1; b = _vec(ubuff),
1802+
linu = _vec(utilde)).cache
18071803
integrator.stats.nsolve += 1
18081804
end
18091805

@@ -1821,9 +1817,8 @@ end
18211817
@.. broadcast=false ubuff=fsallast + tmp
18221818

18231819
if alg.smooth_est
1824-
linres = dolinsolve(integrator, linres.cache; b = _vec(ubuff),
1825-
linu = _vec(utilde))
1826-
cache.linsolve1 = linres.cache
1820+
cache.linsolve1 = dolinsolve(integrator, linsolve1; b = _vec(ubuff),
1821+
linu = _vec(utilde)).cache
18271822
integrator.stats.nsolve += 1
18281823
end
18291824

0 commit comments

Comments
 (0)