@@ -8,6 +8,220 @@ function get_fsalfirstlast(cache::GenericRosenbrockMutableCache, u)
8
8
(cache. fsalfirst, cache. fsallast)
9
9
end
10
10
11
+ @cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
12
+ TabType, TFType, UFType, F, JCType, GCType,
13
+ RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
14
+ u:: uType
15
+ uprev:: uType
16
+ k₁:: rateType
17
+ k₂:: rateType
18
+ k₃:: rateType
19
+ du1:: rateType
20
+ du2:: rateType
21
+ f₁:: rateType
22
+ fsalfirst:: rateType
23
+ fsallast:: rateType
24
+ dT:: rateType
25
+ J:: JType
26
+ W:: WType
27
+ tmp:: rateType
28
+ atmp:: uNoUnitsType
29
+ weight:: uNoUnitsType
30
+ tab:: TabType
31
+ tf:: TFType
32
+ uf:: UFType
33
+ linsolve_tmp:: rateType
34
+ linsolve:: F
35
+ jac_config:: JCType
36
+ grad_config:: GCType
37
+ reltol:: RTolType
38
+ alg:: A
39
+ algebraic_vars:: AV
40
+ step_limiter!:: StepLimiter
41
+ stage_limiter!:: StageLimiter
42
+ end
43
+
44
+ @cache mutable struct Rosenbrock32Cache{uType, rateType, uNoUnitsType, JType, WType,
45
+ TabType, TFType, UFType, F, JCType, GCType,
46
+ RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
47
+ u:: uType
48
+ uprev:: uType
49
+ k₁:: rateType
50
+ k₂:: rateType
51
+ k₃:: rateType
52
+ du1:: rateType
53
+ du2:: rateType
54
+ f₁:: rateType
55
+ fsalfirst:: rateType
56
+ fsallast:: rateType
57
+ dT:: rateType
58
+ J:: JType
59
+ W:: WType
60
+ tmp:: rateType
61
+ atmp:: uNoUnitsType
62
+ weight:: uNoUnitsType
63
+ tab:: TabType
64
+ tf:: TFType
65
+ uf:: UFType
66
+ linsolve_tmp:: rateType
67
+ linsolve:: F
68
+ jac_config:: JCType
69
+ grad_config:: GCType
70
+ reltol:: RTolType
71
+ alg:: A
72
+ algebraic_vars:: AV
73
+ step_limiter!:: StepLimiter
74
+ stage_limiter!:: StageLimiter
75
+ end
76
+
77
+ function alg_cache (alg:: Rosenbrock23 , u, rate_prototype, :: Type{uEltypeNoUnits} ,
78
+ :: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
79
+ dt, reltol, p, calck,
80
+ :: Val{true} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
81
+ k₁ = zero (rate_prototype)
82
+ k₂ = zero (rate_prototype)
83
+ k₃ = zero (rate_prototype)
84
+ du1 = zero (rate_prototype)
85
+ du2 = zero (rate_prototype)
86
+ # f₀ = zero(u) fsalfirst
87
+ f₁ = zero (rate_prototype)
88
+ fsalfirst = zero (rate_prototype)
89
+ fsallast = zero (rate_prototype)
90
+ dT = zero (rate_prototype)
91
+ J, W = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (true ))
92
+ tmp = zero (rate_prototype)
93
+ atmp = similar (u, uEltypeNoUnits)
94
+ recursivefill! (atmp, false )
95
+ weight = similar (u, uEltypeNoUnits)
96
+ recursivefill! (weight, false )
97
+ tab = Rosenbrock23Tableau (constvalue (uBottomEltypeNoUnits))
98
+ tf = TimeGradientWrapper (f, uprev, p)
99
+ uf = UJacobianWrapper (f, t, p)
100
+ linsolve_tmp = zero (rate_prototype)
101
+
102
+ linprob = LinearProblem (W, _vec (linsolve_tmp); u0 = _vec (tmp))
103
+ Pl, Pr = wrapprecs (
104
+ alg. precs (W, nothing , u, p, t, nothing , nothing , nothing ,
105
+ nothing )... , weight, tmp)
106
+ linsolve = init (linprob, alg. linsolve, alias_A = true , alias_b = true ,
107
+ Pl = Pl, Pr = Pr,
108
+ assumptions = LinearSolve. OperatorAssumptions (true ))
109
+
110
+ grad_config = build_grad_config (alg, f, tf, du1, t)
111
+ jac_config = build_jac_config (alg, f, uf, du1, uprev, u, tmp, du2)
112
+ algebraic_vars = f. mass_matrix === I ? nothing :
113
+ [all (iszero, x) for x in eachcol (f. mass_matrix)]
114
+
115
+ Rosenbrock23Cache (u, uprev, k₁, k₂, k₃, du1, du2, f₁,
116
+ fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
117
+ linsolve_tmp,
118
+ linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg. step_limiter!,
119
+ alg. stage_limiter!)
120
+ end
121
+
122
+ function alg_cache (alg:: Rosenbrock32 , u, rate_prototype, :: Type{uEltypeNoUnits} ,
123
+ :: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
124
+ dt, reltol, p, calck,
125
+ :: Val{true} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
126
+ k₁ = zero (rate_prototype)
127
+ k₂ = zero (rate_prototype)
128
+ k₃ = zero (rate_prototype)
129
+ du1 = zero (rate_prototype)
130
+ du2 = zero (rate_prototype)
131
+ # f₀ = zero(u) fsalfirst
132
+ f₁ = zero (rate_prototype)
133
+ fsalfirst = zero (rate_prototype)
134
+ fsallast = zero (rate_prototype)
135
+ dT = zero (rate_prototype)
136
+ J, W = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (true ))
137
+ tmp = zero (rate_prototype)
138
+ atmp = similar (u, uEltypeNoUnits)
139
+ recursivefill! (atmp, false )
140
+ weight = similar (u, uEltypeNoUnits)
141
+ recursivefill! (weight, false )
142
+ tab = Rosenbrock32Tableau (constvalue (uBottomEltypeNoUnits))
143
+
144
+ tf = TimeGradientWrapper (f, uprev, p)
145
+ uf = UJacobianWrapper (f, t, p)
146
+ linsolve_tmp = zero (rate_prototype)
147
+ linprob = LinearProblem (W, _vec (linsolve_tmp); u0 = _vec (tmp))
148
+
149
+ Pl, Pr = wrapprecs (
150
+ alg. precs (W, nothing , u, p, t, nothing , nothing , nothing ,
151
+ nothing )... , weight, tmp)
152
+ linsolve = init (linprob, alg. linsolve, alias_A = true , alias_b = true ,
153
+ Pl = Pl, Pr = Pr,
154
+ assumptions = LinearSolve. OperatorAssumptions (true ))
155
+ grad_config = build_grad_config (alg, f, tf, du1, t)
156
+ jac_config = build_jac_config (alg, f, uf, du1, uprev, u, tmp, du2)
157
+ algebraic_vars = f. mass_matrix === I ? nothing :
158
+ [all (iszero, x) for x in eachcol (f. mass_matrix)]
159
+
160
+ Rosenbrock32Cache (u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W,
161
+ tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config,
162
+ grad_config, reltol, alg, algebraic_vars, alg. step_limiter!, alg. stage_limiter!)
163
+ end
164
+
165
+ struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} < :
166
+ RosenbrockConstantCache
167
+ c₃₂:: T
168
+ d:: T
169
+ tf:: TF
170
+ uf:: UF
171
+ J:: JType
172
+ W:: WType
173
+ linsolve:: F
174
+ autodiff:: AD
175
+ end
176
+
177
+ function Rosenbrock23ConstantCache (:: Type{T} , tf, uf, J, W, linsolve, autodiff) where {T}
178
+ tab = Rosenbrock23Tableau (T)
179
+ Rosenbrock23ConstantCache (tab. c₃₂, tab. d, tf, uf, J, W, linsolve, autodiff)
180
+ end
181
+
182
+ function alg_cache (alg:: Rosenbrock23 , u, rate_prototype, :: Type{uEltypeNoUnits} ,
183
+ :: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
184
+ dt, reltol, p, calck,
185
+ :: Val{false} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
186
+ tf = TimeDerivativeWrapper (f, u, p)
187
+ uf = UDerivativeWrapper (f, t, p)
188
+ J, W = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (false ))
189
+ linprob = nothing # LinearProblem(W,copy(u); u0=copy(u))
190
+ linsolve = nothing # init(linprob,alg.linsolve,alias_A=true,alias_b=true)
191
+ Rosenbrock23ConstantCache (constvalue (uBottomEltypeNoUnits), tf, uf, J, W, linsolve,
192
+ alg_autodiff (alg))
193
+ end
194
+
195
+ struct Rosenbrock32ConstantCache{T, TF, UF, JType, WType, F, AD} < :
196
+ RosenbrockConstantCache
197
+ c₃₂:: T
198
+ d:: T
199
+ tf:: TF
200
+ uf:: UF
201
+ J:: JType
202
+ W:: WType
203
+ linsolve:: F
204
+ autodiff:: AD
205
+ end
206
+
207
+ function Rosenbrock32ConstantCache (:: Type{T} , tf, uf, J, W, linsolve, autodiff) where {T}
208
+ tab = Rosenbrock32Tableau (T)
209
+ Rosenbrock32ConstantCache (tab. c₃₂, tab. d, tf, uf, J, W, linsolve, autodiff)
210
+ end
211
+
212
+ function alg_cache (alg:: Rosenbrock32 , u, rate_prototype, :: Type{uEltypeNoUnits} ,
213
+ :: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
214
+ dt, reltol, p, calck,
215
+ :: Val{false} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
216
+ tf = TimeDerivativeWrapper (f, u, p)
217
+ uf = UDerivativeWrapper (f, t, p)
218
+ J, W = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (false ))
219
+ linprob = nothing # LinearProblem(W,copy(u); u0=copy(u))
220
+ linsolve = nothing # init(linprob,alg.linsolve,alias_A=true,alias_b=true)
221
+ Rosenbrock32ConstantCache (constvalue (uBottomEltypeNoUnits), tf, uf, J, W, linsolve,
222
+ alg_autodiff (alg))
223
+ end
224
+
11
225
# ###############################################################################
12
226
13
227
# Shampine's Low-order Rosenbrocks
@@ -19,7 +233,6 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
19
233
du:: rateType
20
234
du1:: rateType
21
235
du2:: rateType
22
- f₁:: rateType
23
236
ks:: Vector{rateType}
24
237
fsalfirst:: rateType
25
238
fsallast:: rateType
@@ -43,6 +256,7 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
43
256
stage_limiter!:: StageLimiter
44
257
interp_order:: Int
45
258
end
259
+
46
260
function full_cache (c:: RosenbrockCache )
47
261
return [c. u, c. uprev, c. dense... , c. du, c. du1, c. du2,
48
262
c. ks... , c. fsalfirst, c. fsallast, c. dT, c. tmp, c. atmp, c. weight, c. linsolve_tmp]
@@ -98,19 +312,18 @@ tabtype(::Rodas5Pr) = Rodas5PTableau
98
312
tabtype (:: Rodas5Pe ) = Rodas5PTableau
99
313
100
314
function alg_cache (
101
- alg:: Union{Rosenbrock23, Rosenbrock32, ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr} ,
315
+ alg:: Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr} ,
102
316
u, rate_prototype, :: Type{uEltypeNoUnits} ,
103
317
:: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
104
318
dt, reltol, p, calck,
105
319
:: Val{true} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
106
- tab = Rodas5PTableau (constvalue (uBottomEltypeNoUnits), constvalue (tTypeNoUnits))
320
+ tab = tabtype (alg) (constvalue (uBottomEltypeNoUnits), constvalue (tTypeNoUnits))
107
321
dense = [zero (rate_prototype) for _ in 1 : size (tab. H, 1 )]
108
322
du = zero (rate_prototype)
109
323
du1 = zero (rate_prototype)
110
324
du2 = zero (rate_prototype)
111
325
ks = [zero (rate_prototype) for _ in 1 : size (tab. A, 1 )]
112
326
113
- f₁ = zero (rate_prototype)
114
327
fsalfirst = zero (rate_prototype)
115
328
fsallast = zero (rate_prototype)
116
329
dT = zero (rate_prototype)
@@ -135,14 +348,14 @@ function alg_cache(
135
348
jac_config = build_jac_config (alg, f, uf, du1, uprev, u, tmp, du2)
136
349
algebraic_vars = f. mass_matrix === I ? nothing :
137
350
[all (iszero, x) for x in eachcol (f. mass_matrix)]
138
- RosenbrockCache (u, uprev, dense, du, du1, du2, ks, f₁, fsalfirst, fsallast,
351
+ RosenbrockCache (u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast,
139
352
dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
140
353
linsolve, jac_config, grad_config, reltol, alg, algebraic_vars,
141
354
alg. step_limiter!, alg. stage_limiter!, size (tab. H, 1 ))
142
355
end
143
356
144
357
function alg_cache (
145
- alg:: Union{Rosenbrock23, Rosenbrock32, ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr} ,
358
+ alg:: Union{ROS3P, Rodas3, Rodas23W, Rodas4, Rodas42, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr} ,
146
359
u, rate_prototype, :: Type{uEltypeNoUnits} ,
147
360
:: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
148
361
dt, reltol, p, calck,
@@ -152,7 +365,7 @@ function alg_cache(
152
365
J, W = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (false ))
153
366
linprob = nothing # LinearProblem(W,copy(u); u0=copy(u))
154
367
linsolve = nothing # init(linprob,alg.linsolve,alias_A=true,alias_b=true)
155
- tab =
368
+ tab = tabtype (alg)( constvalue (uBottomEltypeNoUnits), constvalue (tTypeNoUnits))
156
369
RosenbrockCombinedConstantCache (tf, uf, tab, J, W, linsolve, alg_autodiff (alg), size (tab. H, 1 ))
157
370
end
158
371
0 commit comments