Skip to content

Commit 0054db2

Browse files
committed
Rosenbrock refactor Rodas5*
1 parent b0f957c commit 0054db2

File tree

7 files changed

+241
-1542
lines changed

7 files changed

+241
-1542
lines changed

lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ end
1111
function DiffEqBase.interp_summary(::Type{cacheType},
1212
dense::Bool) where {
1313
cacheType <:
14-
Union{Rodas4ConstantCache, Rodas23WConstantCache, Rodas3PConstantCache,
14+
Union{RosenbrockCombinedConstantCache, Rodas23WConstantCache, Rodas3PConstantCache,
1515
RosenbrockCache, Rodas23WCache, Rodas3PCache}}
1616
dense ? "specialized 3rd order \"free\" stiffness-aware interpolation" :
1717
"1st order linear"
@@ -20,8 +20,8 @@ end
2020
function DiffEqBase.interp_summary(::Type{cacheType},
2121
dense::Bool) where {
2222
cacheType <:
23-
Union{Rosenbrock5ConstantCache,
24-
Rosenbrock5Cache}}
23+
Union{RosenbrockCombinedConstantCache,
24+
RosenbrockCache}}
2525
dense ? "specialized 4rd order \"free\" stiffness-aware interpolation" :
2626
"1st order linear"
2727
end

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 27 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,24 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
4141
alg::A
4242
step_limiter!::StepLimiter
4343
stage_limiter!::StageLimiter
44+
order::Int
4445
end
4546
function full_cache(c::RosenbrockCache)
4647
return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2,
4748
c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp]
4849
end
4950

51+
struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
52+
tf::TF
53+
uf::UF
54+
tab::Tab
55+
J::JType
56+
W::WType
57+
linsolve::F
58+
autodiff::AD
59+
order::Int
60+
end
61+
5062
@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
5163
TabType, TFType, UFType, F, JCType, GCType,
5264
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
@@ -702,16 +714,6 @@ end
702714

703715
### Rodas4 methods
704716

705-
struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
706-
tf::TF
707-
uf::UF
708-
tab::Tab
709-
J::JType
710-
W::WType
711-
linsolve::F
712-
autodiff::AD
713-
end
714-
715717
tabtype(::Rodas4) = Rodas4Tableau
716718
tabtype(::Rodas42) = Rodas42Tableau
717719
tabtype(::Rodas4P) = Rodas4PTableau
@@ -727,10 +729,10 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
727729
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
728730
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
729731
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
730-
Rodas4ConstantCache(tf, uf,
732+
RosenbrockCombinedConstantCache(tf, uf,
731733
tabtype(alg)(constvalue(uBottomEltypeNoUnits),
732734
constvalue(tTypeNoUnits)), J, W, linsolve,
733-
alg_autodiff(alg))
735+
alg_autodiff(alg), 4)
734736
end
735737

736738
function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
@@ -783,81 +785,22 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
783785
u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast,
784786
dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
785787
linsolve, jac_config, grad_config, reltol, alg,
786-
alg.step_limiter!, alg.stage_limiter!)
788+
alg.step_limiter!, alg.stage_limiter!, 4)
787789
end
788790

789791
################################################################################
790792

791793
### Rosenbrock5
792794

793-
struct Rosenbrock5ConstantCache{TF, UF, Tab, JType, WType, F} <: RosenbrockConstantCache
794-
tf::TF
795-
uf::UF
796-
tab::Tab
797-
J::JType
798-
W::WType
799-
linsolve::F
800-
end
801-
802-
@cache mutable struct Rosenbrock5Cache{
803-
uType, rateType, uNoUnitsType, JType, WType, TabType,
804-
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
805-
RosenbrockMutableCache
806-
u::uType
807-
uprev::uType
808-
dense1::rateType
809-
dense2::rateType
810-
dense3::rateType
811-
du::rateType
812-
du1::rateType
813-
du2::rateType
814-
k1::rateType
815-
k2::rateType
816-
k3::rateType
817-
k4::rateType
818-
k5::rateType
819-
k6::rateType
820-
k7::rateType
821-
k8::rateType
822-
fsalfirst::rateType
823-
fsallast::rateType
824-
dT::rateType
825-
J::JType
826-
W::WType
827-
tmp::rateType
828-
atmp::uNoUnitsType
829-
weight::uNoUnitsType
830-
tab::TabType
831-
tf::TFType
832-
uf::UFType
833-
linsolve_tmp::rateType
834-
linsolve::F
835-
jac_config::JCType
836-
grad_config::GCType
837-
reltol::RTolType
838-
alg::A
839-
step_limiter!::StepLimiter
840-
stage_limiter!::StageLimiter
841-
end
842-
843795
function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
844796
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
845797
dt, reltol, p, calck,
846798
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
847-
dense1 = zero(rate_prototype)
848-
dense2 = zero(rate_prototype)
849-
dense3 = zero(rate_prototype)
799+
dense = [zero(rate_prototype) for _ in 1:3]
850800
du = zero(rate_prototype)
851801
du1 = zero(rate_prototype)
852802
du2 = zero(rate_prototype)
853-
k1 = zero(rate_prototype)
854-
k2 = zero(rate_prototype)
855-
k3 = zero(rate_prototype)
856-
k4 = zero(rate_prototype)
857-
k5 = zero(rate_prototype)
858-
k6 = zero(rate_prototype)
859-
k7 = zero(rate_prototype)
860-
k8 = zero(rate_prototype)
803+
ks = [zero(rate_prototype) for _ in 1:7]
861804
fsalfirst = zero(rate_prototype)
862805
fsallast = zero(rate_prototype)
863806
dT = zero(rate_prototype)
@@ -881,12 +824,11 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
881824
assumptions = LinearSolve.OperatorAssumptions(true))
882825
grad_config = build_grad_config(alg, f, tf, du1, t)
883826
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
884-
Rosenbrock5Cache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4,
885-
k5, k6, k7, k8,
827+
RosenbrockCache(u, uprev, dense, du, du1, du2, ks,
886828
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
887829
linsolve_tmp,
888830
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
889-
alg.stage_limiter!)
831+
alg.stage_limiter!, 5)
890832
end
891833

892834
function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -898,30 +840,21 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
898840
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
899841
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
900842
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
901-
Rosenbrock5ConstantCache(tf, uf,
843+
RosenbrockCombinedConstantCache(tf, uf,
902844
Rodas5Tableau(constvalue(uBottomEltypeNoUnits),
903-
constvalue(tTypeNoUnits)), J, W, linsolve)
845+
constvalue(tTypeNoUnits)), J, W, linsolve, alg_autodiff(alg), 5)
904846
end
905847

906848
function alg_cache(
907849
alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
908850
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
909851
dt, reltol, p, calck,
910852
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
911-
dense1 = zero(rate_prototype)
912-
dense2 = zero(rate_prototype)
913-
dense3 = zero(rate_prototype)
853+
dense = [zero(rate_prototype) for _ in 1:3]
914854
du = zero(rate_prototype)
915855
du1 = zero(rate_prototype)
916856
du2 = zero(rate_prototype)
917-
k1 = zero(rate_prototype)
918-
k2 = zero(rate_prototype)
919-
k3 = zero(rate_prototype)
920-
k4 = zero(rate_prototype)
921-
k5 = zero(rate_prototype)
922-
k6 = zero(rate_prototype)
923-
k7 = zero(rate_prototype)
924-
k8 = zero(rate_prototype)
857+
ks = [zero(rate_prototype) for _ in 1:8]
925858
fsalfirst = zero(rate_prototype)
926859
fsallast = zero(rate_prototype)
927860
dT = zero(rate_prototype)
@@ -945,12 +878,11 @@ function alg_cache(
945878
assumptions = LinearSolve.OperatorAssumptions(true))
946879
grad_config = build_grad_config(alg, f, tf, du1, t)
947880
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
948-
Rosenbrock5Cache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4,
949-
k5, k6, k7, k8,
881+
RosenbrockCache(u, uprev, dense, du, du1, du2, ks,
950882
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
951883
linsolve_tmp,
952884
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
953-
alg.stage_limiter!)
885+
alg.stage_limiter!, 5)
954886
end
955887

956888
function alg_cache(
@@ -963,9 +895,9 @@ function alg_cache(
963895
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
964896
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
965897
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
966-
Rosenbrock5ConstantCache(tf, uf,
898+
RosenbrockCombinedConstantCache(tf, uf,
967899
Rodas5PTableau(constvalue(uBottomEltypeNoUnits),
968-
constvalue(tTypeNoUnits)), J, W, linsolve)
900+
constvalue(tTypeNoUnits)), J, W, linsolve, alg_autodiff(alg), 5)
969901
end
970902

971903
function get_fsalfirstlast(

0 commit comments

Comments
 (0)