Skip to content

Commit f17e149

Browse files
authored
Perform_step! refactor for Rodas5*
2 parents 2fa672b + 0dfb039 commit f17e149

File tree

7 files changed

+253
-1718
lines changed

7 files changed

+253
-1718
lines changed

lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ end
1111
function DiffEqBase.interp_summary(::Type{cacheType},
1212
dense::Bool) where {
1313
cacheType <:
14-
Union{Rodas4ConstantCache, Rodas23WConstantCache, Rodas3PConstantCache,
14+
Union{RosenbrockCombinedConstantCache, Rodas23WConstantCache, Rodas3PConstantCache,
1515
RosenbrockCache, Rodas23WCache, Rodas3PCache}}
1616
dense ? "specialized 3rd order \"free\" stiffness-aware interpolation" :
1717
"1st order linear"
@@ -20,8 +20,8 @@ end
2020
function DiffEqBase.interp_summary(::Type{cacheType},
2121
dense::Bool) where {
2222
cacheType <:
23-
Union{Rosenbrock5ConstantCache,
24-
Rosenbrock5Cache}}
23+
Union{RosenbrockCombinedConstantCache,
24+
RosenbrockCache}}
2525
dense ? "specialized 4rd order \"free\" stiffness-aware interpolation" :
2626
"1st order linear"
2727
end

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 26 additions & 201 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,24 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
4141
alg::A
4242
step_limiter!::StepLimiter
4343
stage_limiter!::StageLimiter
44+
interp_order::Int
4445
end
4546
function full_cache(c::RosenbrockCache)
4647
return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2,
4748
c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp]
4849
end
4950

51+
struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
52+
tf::TF
53+
uf::UF
54+
tab::Tab
55+
J::JType
56+
W::WType
57+
linsolve::F
58+
autodiff::AD
59+
interp_order::Int
60+
end
61+
5062
@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
5163
TabType, TFType, UFType, F, JCType, GCType,
5264
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
@@ -702,22 +714,16 @@ end
702714

703715
### Rodas4 methods
704716

705-
struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
706-
tf::TF
707-
uf::UF
708-
tab::Tab
709-
J::JType
710-
W::WType
711-
linsolve::F
712-
autodiff::AD
713-
end
714-
715717
tabtype(::Rodas4) = Rodas4Tableau
716718
tabtype(::Rodas42) = Rodas42Tableau
717719
tabtype(::Rodas4P) = Rodas4PTableau
718720
tabtype(::Rodas4P2) = Rodas4P2Tableau
721+
tabtype(::Rodas5) = Rodas5Tableau
722+
tabtype(::Rodas5P) = Rodas5PTableau
723+
tabtype(::Rodas5Pr) = Rodas5PTableau
724+
tabtype(::Rodas5Pe) = Rodas5PTableau
719725

720-
function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
726+
function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
721727
u, rate_prototype, ::Type{uEltypeNoUnits},
722728
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
723729
dt, reltol, p, calck,
@@ -727,21 +733,22 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
727733
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
728734
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
729735
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
730-
Rodas4ConstantCache(tf, uf,
731-
tabtype(alg)(constvalue(uBottomEltypeNoUnits),
732-
constvalue(tTypeNoUnits)), J, W, linsolve,
733-
alg_autodiff(alg))
736+
tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
737+
RosenbrockCombinedConstantCache(tf, uf,
738+
tab, J, W, linsolve,
739+
alg_autodiff(alg), size(tab.H, 1))
734740
end
735741

736-
function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
742+
function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
737743
u, rate_prototype, ::Type{uEltypeNoUnits},
738744
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
739745
dt, reltol, p, calck,
740746
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
741747

748+
tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
742749
# Initialize vectors
743-
dense = [zero(rate_prototype) for _ in 1:2]
744-
ks = [zero(rate_prototype) for _ in 1:6]
750+
dense = [zero(rate_prototype) for _ in 1:size(tab.H, 1)]
751+
ks = [zero(rate_prototype) for _ in 1:size(tab.A, 1)]
745752
du = zero(rate_prototype)
746753
du1 = zero(rate_prototype)
747754
du2 = zero(rate_prototype)
@@ -760,7 +767,6 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
760767
recursivefill!(atmp, false)
761768
weight = similar(u, uEltypeNoUnits)
762769
recursivefill!(weight, false)
763-
tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
764770

765771
tf = TimeGradientWrapper(f, uprev, p)
766772
uf = UJacobianWrapper(f, t, p)
@@ -783,190 +789,9 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
783789
u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast,
784790
dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
785791
linsolve, jac_config, grad_config, reltol, alg,
786-
alg.step_limiter!, alg.stage_limiter!)
787-
end
788-
789-
################################################################################
790-
791-
### Rosenbrock5
792-
793-
struct Rosenbrock5ConstantCache{TF, UF, Tab, JType, WType, F} <: RosenbrockConstantCache
794-
tf::TF
795-
uf::UF
796-
tab::Tab
797-
J::JType
798-
W::WType
799-
linsolve::F
800-
end
801-
802-
@cache mutable struct Rosenbrock5Cache{
803-
uType, rateType, uNoUnitsType, JType, WType, TabType,
804-
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
805-
RosenbrockMutableCache
806-
u::uType
807-
uprev::uType
808-
dense1::rateType
809-
dense2::rateType
810-
dense3::rateType
811-
du::rateType
812-
du1::rateType
813-
du2::rateType
814-
k1::rateType
815-
k2::rateType
816-
k3::rateType
817-
k4::rateType
818-
k5::rateType
819-
k6::rateType
820-
k7::rateType
821-
k8::rateType
822-
fsalfirst::rateType
823-
fsallast::rateType
824-
dT::rateType
825-
J::JType
826-
W::WType
827-
tmp::rateType
828-
atmp::uNoUnitsType
829-
weight::uNoUnitsType
830-
tab::TabType
831-
tf::TFType
832-
uf::UFType
833-
linsolve_tmp::rateType
834-
linsolve::F
835-
jac_config::JCType
836-
grad_config::GCType
837-
reltol::RTolType
838-
alg::A
839-
step_limiter!::StepLimiter
840-
stage_limiter!::StageLimiter
792+
alg.step_limiter!, alg.stage_limiter!, size(tab.H, 1))
841793
end
842794

843-
function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
844-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
845-
dt, reltol, p, calck,
846-
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
847-
dense1 = zero(rate_prototype)
848-
dense2 = zero(rate_prototype)
849-
dense3 = zero(rate_prototype)
850-
du = zero(rate_prototype)
851-
du1 = zero(rate_prototype)
852-
du2 = zero(rate_prototype)
853-
k1 = zero(rate_prototype)
854-
k2 = zero(rate_prototype)
855-
k3 = zero(rate_prototype)
856-
k4 = zero(rate_prototype)
857-
k5 = zero(rate_prototype)
858-
k6 = zero(rate_prototype)
859-
k7 = zero(rate_prototype)
860-
k8 = zero(rate_prototype)
861-
fsalfirst = zero(rate_prototype)
862-
fsallast = zero(rate_prototype)
863-
dT = zero(rate_prototype)
864-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
865-
tmp = zero(rate_prototype)
866-
atmp = similar(u, uEltypeNoUnits)
867-
recursivefill!(atmp, false)
868-
weight = similar(u, uEltypeNoUnits)
869-
recursivefill!(weight, false)
870-
tab = Rodas5Tableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
871-
872-
tf = TimeGradientWrapper(f, uprev, p)
873-
uf = UJacobianWrapper(f, t, p)
874-
linsolve_tmp = zero(rate_prototype)
875-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
876-
Pl, Pr = wrapprecs(
877-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
878-
nothing)..., weight, tmp)
879-
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
880-
Pl = Pl, Pr = Pr,
881-
assumptions = LinearSolve.OperatorAssumptions(true))
882-
grad_config = build_grad_config(alg, f, tf, du1, t)
883-
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
884-
Rosenbrock5Cache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4,
885-
k5, k6, k7, k8,
886-
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
887-
linsolve_tmp,
888-
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
889-
alg.stage_limiter!)
890-
end
891-
892-
function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
893-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
894-
dt, reltol, p, calck,
895-
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
896-
tf = TimeDerivativeWrapper(f, u, p)
897-
uf = UDerivativeWrapper(f, t, p)
898-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
899-
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
900-
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
901-
Rosenbrock5ConstantCache(tf, uf,
902-
Rodas5Tableau(constvalue(uBottomEltypeNoUnits),
903-
constvalue(tTypeNoUnits)), J, W, linsolve)
904-
end
905-
906-
function alg_cache(
907-
alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
908-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
909-
dt, reltol, p, calck,
910-
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
911-
dense1 = zero(rate_prototype)
912-
dense2 = zero(rate_prototype)
913-
dense3 = zero(rate_prototype)
914-
du = zero(rate_prototype)
915-
du1 = zero(rate_prototype)
916-
du2 = zero(rate_prototype)
917-
k1 = zero(rate_prototype)
918-
k2 = zero(rate_prototype)
919-
k3 = zero(rate_prototype)
920-
k4 = zero(rate_prototype)
921-
k5 = zero(rate_prototype)
922-
k6 = zero(rate_prototype)
923-
k7 = zero(rate_prototype)
924-
k8 = zero(rate_prototype)
925-
fsalfirst = zero(rate_prototype)
926-
fsallast = zero(rate_prototype)
927-
dT = zero(rate_prototype)
928-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
929-
tmp = zero(rate_prototype)
930-
atmp = similar(u, uEltypeNoUnits)
931-
recursivefill!(atmp, false)
932-
weight = similar(u, uEltypeNoUnits)
933-
recursivefill!(weight, false)
934-
tab = Rodas5PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
935-
936-
tf = TimeGradientWrapper(f, uprev, p)
937-
uf = UJacobianWrapper(f, t, p)
938-
linsolve_tmp = zero(rate_prototype)
939-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
940-
Pl, Pr = wrapprecs(
941-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
942-
nothing)..., weight, tmp)
943-
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
944-
Pl = Pl, Pr = Pr,
945-
assumptions = LinearSolve.OperatorAssumptions(true))
946-
grad_config = build_grad_config(alg, f, tf, du1, t)
947-
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
948-
Rosenbrock5Cache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4,
949-
k5, k6, k7, k8,
950-
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
951-
linsolve_tmp,
952-
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
953-
alg.stage_limiter!)
954-
end
955-
956-
function alg_cache(
957-
alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
958-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
959-
dt, reltol, p, calck,
960-
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
961-
tf = TimeDerivativeWrapper(f, u, p)
962-
uf = UDerivativeWrapper(f, t, p)
963-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
964-
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
965-
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
966-
Rosenbrock5ConstantCache(tf, uf,
967-
Rodas5PTableau(constvalue(uBottomEltypeNoUnits),
968-
constvalue(tTypeNoUnits)), J, W, linsolve)
969-
end
970795

971796
function get_fsalfirstlast(
972797
cache::Union{Rosenbrock23Cache, Rosenbrock32Cache, Rosenbrock33Cache,

0 commit comments

Comments
 (0)