Skip to content

Commit 733bfa5

Browse files
Merge pull request #2450 from Shreyas-Ekanathan/master
Implement Part 1 Of Adaptive Radau Method
2 parents e6ddc71 + bdb0a63 commit 733bfa5

File tree

9 files changed

+1211
-95
lines changed

9 files changed

+1211
-95
lines changed

lib/OrdinaryDiffEqFIRK/Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,20 @@ version = "1.1.1"
66
[deps]
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
88
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
9+
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
10+
GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e"
911
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1012
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1113
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
1214
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
1315
OrdinaryDiffEqDifferentiation = "4302a76b-040a-498a-8c04-15b101fed76b"
1416
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
17+
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
1518
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1619
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
20+
RootedTrees = "47965b36-3f3e-11e9-0dcf-4570dfd42a8c"
1721
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
22+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1823

1924
[compat]
2025
DiffEqBase = "6.152.2"

lib/OrdinaryDiffEqFIRK/src/OrdinaryDiffEqFIRK.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
1818
get_current_adaptive_order, get_fsalfirstlast,
1919
isfirk, generic_solver_docstring
2020
using MuladdMacro, DiffEqBase, RecursiveArrayTools
21+
using Polynomials, GenericLinearAlgebra, GenericSchur
2122
using SciMLOperators: AbstractSciMLOperator
2223
using LinearAlgebra: I, UniformScaling, mul!, lu
2324
import LinearSolve
@@ -42,6 +43,6 @@ include("firk_tableaus.jl")
4243
include("firk_perform_step.jl")
4344
include("integrator_interface.jl")
4445

45-
export RadauIIA3, RadauIIA5, RadauIIA9
46+
export RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau
4647

4748
end

lib/OrdinaryDiffEqFIRK/src/alg_utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ qmax_default(alg::Union{RadauIIA3, RadauIIA5, RadauIIA9}) = 8
33
alg_order(alg::RadauIIA3) = 3
44
alg_order(alg::RadauIIA5) = 5
55
alg_order(alg::RadauIIA9) = 9
6+
alg_order(alg::AdaptiveRadau) = 5
67

78
isfirk(alg::RadauIIA3) = true
89
isfirk(alg::RadauIIA5) = true
910
isfirk(alg::RadauIIA9) = true
11+
isfirk(alg::AdaptiveRadau) = true
1012

1113
alg_adaptive_order(alg::RadauIIA3) = 1
1214
alg_adaptive_order(alg::RadauIIA5) = 3

lib/OrdinaryDiffEqFIRK/src/algorithms.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,42 @@ function RadauIIA9(; chunk_size = Val{0}(), autodiff = Val{true}(),
150150
controller,
151151
step_limiter!)
152152
end
153+
154+
struct AdaptiveRadau{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <:
155+
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
156+
linsolve::F
157+
precs::P
158+
smooth_est::Bool
159+
extrapolant::Symbol
160+
κ::Tol
161+
maxiters::Int
162+
fast_convergence_cutoff::C1
163+
new_W_γdt_cutoff::C2
164+
controller::Symbol
165+
step_limiter!::StepLimiter
166+
num_stages::Int
167+
end
168+
169+
function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
170+
standardtag = Val{true}(), concrete_jac = nothing,
171+
diff_type = Val{:forward}, num_stages = 3,
172+
linsolve = nothing, precs = DEFAULT_PRECS,
173+
extrapolant = :dense, fast_convergence_cutoff = 1 // 5,
174+
new_W_γdt_cutoff = 1 // 5,
175+
controller = :Predictive, κ = nothing, maxiters = 10, smooth_est = true,
176+
step_limiter! = trivial_limiter!)
177+
AdaptiveRadau{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
178+
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
179+
typeof(κ), typeof(fast_convergence_cutoff),
180+
typeof(new_W_γdt_cutoff), typeof(step_limiter!)}(linsolve,
181+
precs,
182+
smooth_est,
183+
extrapolant,
184+
κ,
185+
maxiters,
186+
fast_convergence_cutoff,
187+
new_W_γdt_cutoff,
188+
controller,
189+
step_limiter!, num_stages)
190+
end
191+

lib/OrdinaryDiffEqFIRK/src/firk_caches.jl

Lines changed: 187 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ mutable struct RadauIIA9ConstantCache{F, Tab, Tol, Dt, U, JType} <:
287287
cont2::U
288288
cont3::U
289289
cont4::U
290+
cont5::U
290291
dtprev::Dt
291292
W_γdt::Dt
292293
status::NLStatus
@@ -304,7 +305,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
304305
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
305306
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'
306307

307-
RadauIIA9ConstantCache(uf, tab, κ, one(uToltype), 10000, u, u, u, u, dt, dt,
308+
RadauIIA9ConstantCache(uf, tab, κ, one(uToltype), 10000, u, u, u, u, u, dt, dt,
308309
Convergence, J)
309310
end
310311

@@ -333,6 +334,7 @@ mutable struct RadauIIA9Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Ty
333334
cont2::uType
334335
cont3::uType
335336
cont4::uType
337+
cont5::uType
336338
du1::rateType
337339
fsalfirst::rateType
338340
k::rateType
@@ -407,6 +409,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
407409
cont2 = zero(u)
408410
cont3 = zero(u)
409411
cont4 = zero(u)
412+
cont5 = zero(u)
410413

411414
fsalfirst = zero(rate_prototype)
412415
k = zero(rate_prototype)
@@ -462,11 +465,193 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
462465

463466
RadauIIA9Cache(u, uprev,
464467
z1, z2, z3, z4, z5, w1, w2, w3, w4, w5,
465-
dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4,
468+
dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4, cont5,
466469
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
467470
J, W1, W2, W3,
468471
uf, tab, κ, one(uToltype), 10000,
469472
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config,
470473
linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt,
471474
Convergence, alg.step_limiter!)
472475
end
476+
477+
mutable struct AdaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType} <:
478+
OrdinaryDiffEqConstantCache
479+
uf::F
480+
tab::Tab
481+
κ::Tol
482+
ηold::Tol
483+
iter::Int
484+
cont::Vector{U}
485+
dtprev::Dt
486+
W_γdt::Dt
487+
status::NLStatus
488+
J::JType
489+
end
490+
491+
function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
492+
::Type{uBottomEltypeNoUnits},
493+
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
494+
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
495+
uf = UDerivativeWrapper(f, t, p)
496+
uToltype = constvalue(uBottomEltypeNoUnits)
497+
num_stages = alg.num_stages
498+
499+
if (num_stages == 3)
500+
tab = BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits))
501+
elseif (num_stages == 5)
502+
tab = BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits))
503+
elseif (num_stages == 7)
504+
tab = BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))
505+
elseif iseven(num_stages) || num_stages <3
506+
error("num_stages must be odd and 3 or greater")
507+
else
508+
tab = adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), num_stages)
509+
end
510+
511+
cont = Vector{typeof(u)}(undef, num_stages)
512+
for i in 1: num_stages
513+
cont[i] = zero(u)
514+
end
515+
516+
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
517+
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'
518+
519+
AdaptiveRadauConstantCache(uf, tab, κ, one(uToltype), 10000, cont, dt, dt,
520+
Convergence, J)
521+
end
522+
523+
mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType, JType, W1Type, W2Type,
524+
UF, JC, F1, F2, Tab, Tol, Dt, rTol, aTol, StepLimiter} <:
525+
FIRKMutableCache
526+
u::uType
527+
uprev::uType
528+
z::Vector{uType}
529+
w::Vector{uType}
530+
c_prime::Vector{tType}
531+
dw1::uType
532+
ubuff::uType
533+
dw2::Vector{cuType}
534+
cubuff::Vector{cuType}
535+
dw::Vector{uType}
536+
cont::Vector{uType}
537+
derivatives:: Matrix{uType}
538+
du1::rateType
539+
fsalfirst::rateType
540+
ks::Vector{rateType}
541+
k::rateType
542+
fw::Vector{rateType}
543+
J::JType
544+
W1::W1Type #real
545+
W2::Vector{W2Type} #complex
546+
uf::UF
547+
tab::Tab
548+
κ::Tol
549+
ηold::Tol
550+
iter::Int
551+
tmp::uType
552+
atmp::uNoUnitsType
553+
jac_config::JC
554+
linsolve1::F1 #real
555+
linsolve2::Vector{F2} #complex
556+
rtol::rTol
557+
atol::aTol
558+
dtprev::Dt
559+
W_γdt::Dt
560+
status::NLStatus
561+
step_limiter!::StepLimiter
562+
end
563+
564+
function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
565+
::Type{uBottomEltypeNoUnits},
566+
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
567+
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
568+
uf = UJacobianWrapper(f, t, p)
569+
uToltype = constvalue(uBottomEltypeNoUnits)
570+
num_stages = alg.num_stages
571+
572+
if (num_stages == 3)
573+
tab = BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits))
574+
elseif (num_stages == 5)
575+
tab = BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits))
576+
elseif (num_stages == 7)
577+
tab = BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))
578+
elseif iseven(num_stages) || num_stages < 3
579+
error("num_stages must be odd and 3 or greater")
580+
else
581+
tab = adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), num_stages)
582+
end
583+
584+
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
585+
586+
z = Vector{typeof(u)}(undef, num_stages)
587+
w = Vector{typeof(u)}(undef, num_stages)
588+
for i in 1 : num_stages
589+
z[i] = w[i] = zero(u)
590+
end
591+
592+
c_prime = Vector{typeof(t)}(undef, num_stages) #time stepping
593+
594+
dw1 = zero(u)
595+
ubuff = zero(u)
596+
dw2 = [similar(u, Complex{eltype(u)}) for _ in 1 : (num_stages - 1) ÷ 2]
597+
recursivefill!.(dw2, false)
598+
cubuff = [similar(u, Complex{eltype(u)}) for _ in 1 : (num_stages - 1) ÷ 2]
599+
recursivefill!.(cubuff, false)
600+
dw = Vector{typeof(u)}(undef, num_stages - 1)
601+
602+
cont = Vector{typeof(u)}(undef, num_stages)
603+
for i in 1 : num_stages
604+
cont[i] = zero(u)
605+
end
606+
607+
derivatives = Matrix{typeof(u)}(undef, num_stages, num_stages)
608+
for i in 1 : num_stages, j in 1 : num_stages
609+
derivatives[i, j] = zero(u)
610+
end
611+
612+
fsalfirst = zero(rate_prototype)
613+
fw = Vector{typeof(rate_prototype)}(undef, num_stages)
614+
ks = Vector{typeof(rate_prototype)}(undef, num_stages)
615+
for i in 1: num_stages
616+
ks[i] = fw[i] = zero(rate_prototype)
617+
end
618+
k = ks[1]
619+
620+
J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
621+
if J isa AbstractSciMLOperator
622+
error("Non-concrete Jacobian not yet supported by AdaptiveRadau.")
623+
end
624+
625+
W2 = [similar(J, Complex{eltype(W1)}) for _ in 1 : (num_stages - 1) ÷ 2]
626+
recursivefill!.(W2, false)
627+
628+
du1 = zero(rate_prototype)
629+
630+
tmp = zero(u)
631+
632+
atmp = similar(u, uEltypeNoUnits)
633+
recursivefill!(atmp, false)
634+
635+
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, zero(u), dw1)
636+
637+
linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
638+
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
639+
assumptions = LinearSolve.OperatorAssumptions(true))
640+
641+
linsolve2 = [
642+
init(LinearProblem(W2[i], _vec(cubuff[i]); u0 = _vec(dw2[i])), alg.linsolve, alias_A = true, alias_b = true,
643+
assumptions = LinearSolve.OperatorAssumptions(true)) for i in 1 : (num_stages - 1) ÷ 2]
644+
645+
rtol = reltol isa Number ? reltol : zero(reltol)
646+
atol = reltol isa Number ? reltol : zero(reltol)
647+
648+
AdaptiveRadauCache(u, uprev,
649+
z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
650+
du1, fsalfirst, ks, k, fw,
651+
J, W1, W2,
652+
uf, tab, κ, one(uToltype), 10000, tmp,
653+
atmp, jac_config,
654+
linsolve1, linsolve2, rtol, atol, dt, dt,
655+
Convergence, alg.step_limiter!)
656+
end
657+

0 commit comments

Comments
 (0)