Skip to content

Commit acd2414

Browse files
committed
don't update Rosenbrock23/32
1 parent 71233d2 commit acd2414

File tree

4 files changed

+709
-74
lines changed

4 files changed

+709
-74
lines changed

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 220 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,220 @@ function get_fsalfirstlast(cache::GenericRosenbrockMutableCache, u)
88
(cache.fsalfirst, cache.fsallast)
99
end
1010

11+
@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
12+
TabType, TFType, UFType, F, JCType, GCType,
13+
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
14+
u::uType
15+
uprev::uType
16+
k₁::rateType
17+
k₂::rateType
18+
k₃::rateType
19+
du1::rateType
20+
du2::rateType
21+
f₁::rateType
22+
fsalfirst::rateType
23+
fsallast::rateType
24+
dT::rateType
25+
J::JType
26+
W::WType
27+
tmp::rateType
28+
atmp::uNoUnitsType
29+
weight::uNoUnitsType
30+
tab::TabType
31+
tf::TFType
32+
uf::UFType
33+
linsolve_tmp::rateType
34+
linsolve::F
35+
jac_config::JCType
36+
grad_config::GCType
37+
reltol::RTolType
38+
alg::A
39+
algebraic_vars::AV
40+
step_limiter!::StepLimiter
41+
stage_limiter!::StageLimiter
42+
end
43+
44+
@cache mutable struct Rosenbrock32Cache{uType, rateType, uNoUnitsType, JType, WType,
45+
TabType, TFType, UFType, F, JCType, GCType,
46+
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
47+
u::uType
48+
uprev::uType
49+
k₁::rateType
50+
k₂::rateType
51+
k₃::rateType
52+
du1::rateType
53+
du2::rateType
54+
f₁::rateType
55+
fsalfirst::rateType
56+
fsallast::rateType
57+
dT::rateType
58+
J::JType
59+
W::WType
60+
tmp::rateType
61+
atmp::uNoUnitsType
62+
weight::uNoUnitsType
63+
tab::TabType
64+
tf::TFType
65+
uf::UFType
66+
linsolve_tmp::rateType
67+
linsolve::F
68+
jac_config::JCType
69+
grad_config::GCType
70+
reltol::RTolType
71+
alg::A
72+
algebraic_vars::AV
73+
step_limiter!::StepLimiter
74+
stage_limiter!::StageLimiter
75+
end
76+
77+
function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
78+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
79+
dt, reltol, p, calck,
80+
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
81+
k₁ = zero(rate_prototype)
82+
k₂ = zero(rate_prototype)
83+
k₃ = zero(rate_prototype)
84+
du1 = zero(rate_prototype)
85+
du2 = zero(rate_prototype)
86+
# f₀ = zero(u) fsalfirst
87+
f₁ = zero(rate_prototype)
88+
fsalfirst = zero(rate_prototype)
89+
fsallast = zero(rate_prototype)
90+
dT = zero(rate_prototype)
91+
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
92+
tmp = zero(rate_prototype)
93+
atmp = similar(u, uEltypeNoUnits)
94+
recursivefill!(atmp, false)
95+
weight = similar(u, uEltypeNoUnits)
96+
recursivefill!(weight, false)
97+
tab = Rosenbrock23Tableau(constvalue(uBottomEltypeNoUnits))
98+
tf = TimeGradientWrapper(f, uprev, p)
99+
uf = UJacobianWrapper(f, t, p)
100+
linsolve_tmp = zero(rate_prototype)
101+
102+
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
103+
Pl, Pr = wrapprecs(
104+
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
105+
nothing)..., weight, tmp)
106+
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
107+
Pl = Pl, Pr = Pr,
108+
assumptions = LinearSolve.OperatorAssumptions(true))
109+
110+
grad_config = build_grad_config(alg, f, tf, du1, t)
111+
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
112+
algebraic_vars = f.mass_matrix === I ? nothing :
113+
[all(iszero, x) for x in eachcol(f.mass_matrix)]
114+
115+
Rosenbrock23Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁,
116+
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
117+
linsolve_tmp,
118+
linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!,
119+
alg.stage_limiter!)
120+
end
121+
122+
function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
123+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
124+
dt, reltol, p, calck,
125+
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
126+
k₁ = zero(rate_prototype)
127+
k₂ = zero(rate_prototype)
128+
k₃ = zero(rate_prototype)
129+
du1 = zero(rate_prototype)
130+
du2 = zero(rate_prototype)
131+
# f₀ = zero(u) fsalfirst
132+
f₁ = zero(rate_prototype)
133+
fsalfirst = zero(rate_prototype)
134+
fsallast = zero(rate_prototype)
135+
dT = zero(rate_prototype)
136+
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
137+
tmp = zero(rate_prototype)
138+
atmp = similar(u, uEltypeNoUnits)
139+
recursivefill!(atmp, false)
140+
weight = similar(u, uEltypeNoUnits)
141+
recursivefill!(weight, false)
142+
tab = Rosenbrock32Tableau(constvalue(uBottomEltypeNoUnits))
143+
144+
tf = TimeGradientWrapper(f, uprev, p)
145+
uf = UJacobianWrapper(f, t, p)
146+
linsolve_tmp = zero(rate_prototype)
147+
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
148+
149+
Pl, Pr = wrapprecs(
150+
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
151+
nothing)..., weight, tmp)
152+
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
153+
Pl = Pl, Pr = Pr,
154+
assumptions = LinearSolve.OperatorAssumptions(true))
155+
grad_config = build_grad_config(alg, f, tf, du1, t)
156+
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
157+
algebraic_vars = f.mass_matrix === I ? nothing :
158+
[all(iszero, x) for x in eachcol(f.mass_matrix)]
159+
160+
Rosenbrock32Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W,
161+
tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config,
162+
grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, alg.stage_limiter!)
163+
end
164+
165+
struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} <:
166+
RosenbrockConstantCache
167+
c₃₂::T
168+
d::T
169+
tf::TF
170+
uf::UF
171+
J::JType
172+
W::WType
173+
linsolve::F
174+
autodiff::AD
175+
end
176+
177+
function Rosenbrock23ConstantCache(::Type{T}, tf, uf, J, W, linsolve, autodiff) where {T}
178+
tab = Rosenbrock23Tableau(T)
179+
Rosenbrock23ConstantCache(tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff)
180+
end
181+
182+
function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
183+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
184+
dt, reltol, p, calck,
185+
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
186+
tf = TimeDerivativeWrapper(f, u, p)
187+
uf = UDerivativeWrapper(f, t, p)
188+
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
189+
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
190+
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
191+
Rosenbrock23ConstantCache(constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve,
192+
alg_autodiff(alg))
193+
end
194+
195+
struct Rosenbrock32ConstantCache{T, TF, UF, JType, WType, F, AD} <:
196+
RosenbrockConstantCache
197+
c₃₂::T
198+
d::T
199+
tf::TF
200+
uf::UF
201+
J::JType
202+
W::WType
203+
linsolve::F
204+
autodiff::AD
205+
end
206+
207+
function Rosenbrock32ConstantCache(::Type{T}, tf, uf, J, W, linsolve, autodiff) where {T}
208+
tab = Rosenbrock32Tableau(T)
209+
Rosenbrock32ConstantCache(tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff)
210+
end
211+
212+
function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
213+
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
214+
dt, reltol, p, calck,
215+
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
216+
tf = TimeDerivativeWrapper(f, u, p)
217+
uf = UDerivativeWrapper(f, t, p)
218+
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
219+
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
220+
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
221+
Rosenbrock32ConstantCache(constvalue(uBottomEltypeNoUnits), tf, uf, J, W, linsolve,
222+
alg_autodiff(alg))
223+
end
224+
11225
################################################################################
12226

13227
# Shampine's Low-order Rosenbrocks
@@ -19,7 +233,6 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
19233
du::rateType
20234
du1::rateType
21235
du2::rateType
22-
f₁::rateType
23236
ks::Vector{rateType}
24237
fsalfirst::rateType
25238
fsallast::rateType
@@ -43,6 +256,7 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
43256
stage_limiter!::StageLimiter
44257
interp_order::Int
45258
end
259+
46260
function full_cache(c::RosenbrockCache)
47261
return [c.u, c.uprev, c.dense..., c.du, c.du1, c.du2,
48262
c.ks..., c.fsalfirst, c.fsallast, c.dT, c.tmp, c.atmp, c.weight, c.linsolve_tmp]
@@ -98,19 +312,18 @@ tabtype(::Rodas5Pr) = Rodas5PTableau
98312
tabtype(::Rodas5Pe) = Rodas5PTableau
99313

100314
function alg_cache(
101-
alg::Union{Rosenbrock23, Rosenbrock32, ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
315+
alg::Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
102316
u, rate_prototype, ::Type{uEltypeNoUnits},
103317
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
104318
dt, reltol, p, calck,
105319
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
106-
tab = Rodas5PTableau(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
320+
tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
107321
dense = [zero(rate_prototype) for _ in 1:size(tab.H, 1)]
108322
du = zero(rate_prototype)
109323
du1 = zero(rate_prototype)
110324
du2 = zero(rate_prototype)
111325
ks = [zero(rate_prototype) for _ in 1:size(tab.A, 1)]
112326

113-
f₁ = zero(rate_prototype)
114327
fsalfirst = zero(rate_prototype)
115328
fsallast = zero(rate_prototype)
116329
dT = zero(rate_prototype)
@@ -135,14 +348,14 @@ function alg_cache(
135348
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
136349
algebraic_vars = f.mass_matrix === I ? nothing :
137350
[all(iszero, x) for x in eachcol(f.mass_matrix)]
138-
RosenbrockCache(u, uprev, dense, du, du1, du2, ks, f₁, fsalfirst, fsallast,
351+
RosenbrockCache(u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast,
139352
dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
140353
linsolve, jac_config, grad_config, reltol, alg, algebraic_vars,
141354
alg.step_limiter!, alg.stage_limiter!, size(tab.H, 1))
142355
end
143356

144357
function alg_cache(
145-
alg::Union{Rosenbrock23, Rosenbrock32, ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
358+
alg::Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr},
146359
u, rate_prototype, ::Type{uEltypeNoUnits},
147360
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
148361
dt, reltol, p, calck,
@@ -152,7 +365,7 @@ function alg_cache(
152365
J, W = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(false))
153366
linprob = nothing #LinearProblem(W,copy(u); u0=copy(u))
154367
linsolve = nothing #init(linprob,alg.linsolve,alias_A=true,alias_b=true)
155-
tab =
368+
tab = tabtype(alg)(constvalue(uBottomEltypeNoUnits), constvalue(tTypeNoUnits))
156369
RosenbrockCombinedConstantCache(tf, uf, tab, J, W, linsolve, alg_autodiff(alg), size(tab.H, 1))
157370
end
158371

0 commit comments

Comments
 (0)