33
33
mass_matrix = integrator. f. mass_matrix
34
34
35
35
# Precalculations
36
- γ = dt * d
36
+ dtγ = dt * d
37
+ neginvdtγ = - inv (dtγ)
37
38
dto2 = dt / 2
38
39
dto6 = dt / 6
39
40
42
43
OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
43
44
end
44
45
45
- calc_rosenbrock_differentiation! (integrator, cache, γ, γ , repeat_step, false )
46
+ calc_rosenbrock_differentiation! (integrator, cache, dtγ, dtγ , repeat_step, true )
46
47
47
48
calculate_residuals! (weight, fill! (weight, one (eltype (u))), uprev, uprev,
48
49
integrator. opts. abstol, integrator. opts. reltol,
52
53
linres = dolinsolve (
53
54
integrator, cache. linsolve; A = nothing , b = _vec (linsolve_tmp),
54
55
du = integrator. fsalfirst, u = u, p = p, t = t, weight = weight,
55
- solverdata = (; gamma = γ ))
56
+ solverdata = (; gamma = dtγ ))
56
57
else
57
58
linres = dolinsolve (integrator, cache. linsolve; A = W, b = _vec (linsolve_tmp),
58
59
du = integrator. fsalfirst, u = u, p = p, t = t, weight = weight,
59
- solverdata = (; gamma = γ ))
60
+ solverdata = (; gamma = dtγ ))
60
61
end
61
62
62
63
vecu = _vec (linres. u)
63
64
veck₁ = _vec (k₁)
64
65
65
- @. . broadcast = false veck₁= - vecu
66
+ @. . veck₁ = vecu * neginvdtγ
66
67
integrator. stats. nsolve += 1
67
68
68
- @. . broadcast = false u= uprev + dto2 * k₁
69
+ @. . u= uprev + dto2 * k₁
69
70
stage_limiter! (u, integrator, p, t + dto2)
70
71
f (f₁, u, p, t + dto2)
71
72
OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
76
77
mul! (_vec (tmp), mass_matrix, _vec (k₁))
77
78
end
78
79
79
- @. . broadcast = false linsolve_tmp= f₁ - tmp
80
+ @. . linsolve_tmp = f₁ - tmp
80
81
81
82
linres = dolinsolve (integrator, linres. cache; b = _vec (linsolve_tmp))
82
83
vecu = _vec (linres. u)
83
- veck2 = _vec (k₂)
84
+ veck₂ = _vec (k₂)
84
85
85
- @. . broadcast = false veck2 = - vecu
86
+ @. . veck₂ = vecu * neginvdtγ + veck₁
86
87
integrator. stats. nsolve += 1
87
88
88
- @. . broadcast= false k₂+= k₁
89
- @. . broadcast= false u= uprev + dt * k₂
89
+ @. . u = uprev + dt * k₂
90
90
stage_limiter! (u, integrator, p, t + dt)
91
91
step_limiter! (u, integrator, p, t + dt)
92
92
107
107
linres = dolinsolve (integrator, linres. cache; b = _vec (linsolve_tmp))
108
108
vecu = _vec (linres. u)
109
109
veck3 = _vec (k₃)
110
- @. . broadcast = false veck3= - vecu
110
+ @. . veck3 = vecu * neginvdtγ
111
111
112
112
integrator. stats. nsolve += 1
113
113
127
127
128
128
if mass_matrix != = I
129
129
algvar = reshape (cache. algebraic_vars, size (u))
130
- @. . broadcast = false atmp = ifelse (algvar, fsallast, false ) /
131
- integrator . opts . abstol
130
+ invatol = inv (integrator . opts . abstol)
131
+ @. . atmp = ifelse (algvar, fsallast, false ) * invatol
132
132
integrator. EEst += integrator. opts. internalnorm (atmp, t)
133
133
end
134
134
end
145
145
mass_matrix = integrator. f. mass_matrix
146
146
147
147
# Precalculations
148
- γ = dt * d
148
+ dtγ = dt * d
149
+ neginvdtγ = - inv (dtγ)
149
150
dto2 = dt / 2
150
151
dto6 = dt / 6
151
152
154
155
OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
155
156
end
156
157
157
- calc_rosenbrock_differentiation! (integrator, cache, γ, γ , repeat_step, false )
158
+ calc_rosenbrock_differentiation! (integrator, cache, dtγ, dtγ , repeat_step, true )
158
159
159
160
calculate_residuals! (weight, fill! (weight, one (eltype (u))), uprev, uprev,
160
161
integrator. opts. abstol, integrator. opts. reltol,
@@ -164,17 +165,17 @@ end
164
165
linres = dolinsolve (
165
166
integrator, cache. linsolve; A = nothing , b = _vec (linsolve_tmp),
166
167
du = integrator. fsalfirst, u = u, p = p, t = t, weight = weight,
167
- solverdata = (; gamma = γ ))
168
+ solverdata = (; gamma = dtγ ))
168
169
else
169
170
linres = dolinsolve (integrator, cache. linsolve; A = W, b = _vec (linsolve_tmp),
170
171
du = integrator. fsalfirst, u = u, p = p, t = t, weight = weight,
171
- solverdata = (; gamma = γ ))
172
+ solverdata = (; gamma = dtγ ))
172
173
end
173
174
174
175
vecu = _vec (linres. u)
175
176
veck₁ = _vec (k₁)
176
177
177
- @. . broadcast = false veck₁= - vecu
178
+ @. . veck₁ = vecu * neginvdtγ
178
179
integrator. stats. nsolve += 1
179
180
180
181
@. . broadcast= false u= uprev + dto2 * k₁
@@ -192,13 +193,12 @@ end
192
193
193
194
linres = dolinsolve (integrator, linres. cache; b = _vec (linsolve_tmp))
194
195
vecu = _vec (linres. u)
195
- veck2 = _vec (k₂)
196
+ veck₂ = _vec (k₂)
196
197
197
- @. . broadcast = false veck2 = - vecu
198
+ @. . veck₂ = vecu * neginvdtγ + veck₁
198
199
integrator. stats. nsolve += 1
199
200
200
- @. . broadcast= false k₂+= k₁
201
- @. . broadcast= false tmp= uprev + dt * k₂
201
+ @. . tmp = uprev + dt * k₂
202
202
stage_limiter! (u, integrator, p, t + dt)
203
203
f (fsallast, tmp, p, t + dt)
204
204
OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
216
216
vecu = _vec (linres. u)
217
217
veck3 = _vec (k₃)
218
218
219
- @. . broadcast = false veck3= - vecu
219
+ @. . veck3 = vecu * neginvdtγ
220
220
integrator. stats. nsolve += 1
221
221
222
222
@. . broadcast= false u= uprev + dto6 * (k₁ + 4 k₂ + k₃)
230
230
integrator. EEst = integrator. opts. internalnorm (atmp, t)
231
231
232
232
if mass_matrix != = I
233
- @. . broadcast = false atmp = ifelse (cache . algebraic_vars, fsallast, false ) /
234
- integrator . opts . abstol
233
+ invatol = inv (integrator . opts . abstol)
234
+ @. . atmp = ifelse (cache . algebraic_vars, fsallast, false ) * invatol
235
235
integrator. EEst += integrator. opts. internalnorm (atmp, t)
236
236
end
237
237
end
244
244
@unpack c₃₂, d, tf, uf = cache
245
245
246
246
# Precalculations
247
- γ = dt * d
247
+ dtγ = dt * d
248
+ neginvdtγ = - inv (dtγ)
248
249
dto2 = dt / 2
249
250
dto6 = dt / 6
250
251
@@ -258,22 +259,24 @@ end
258
259
# Time derivative
259
260
dT = calc_tderivative (integrator, cache)
260
261
261
- W = calc_W (integrator, cache, γ , repeat_step)
262
+ W = calc_W (integrator, cache, dtγ , repeat_step, true )
262
263
if ! issuccess_W (W)
263
264
integrator. EEst = 2
264
265
return nothing
265
266
end
266
267
267
- k₁ = _reshape (W \ - _vec ((integrator. fsalfirst + γ * dT)), axes (uprev))
268
+ k₁ = _reshape (W \ _vec ((integrator. fsalfirst + dtγ * dT)), axes (uprev)) * neginvdtγ
268
269
integrator. stats. nsolve += 1
269
- f₁ = f (uprev + dto2 * k₁, p, t + dto2)
270
+ tmp = @. . uprev + dto2 * k₁
271
+ f₁ = f (tmp, p, t + dto2)
270
272
OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
271
273
272
274
if mass_matrix === I
273
- k₂ = _reshape (W \ - _vec (f₁ - k₁), axes (uprev)) + k₁
275
+ k₂ = _reshape (W \ _vec (f₁ - k₁), axes (uprev))
274
276
else
275
- k₂ = _reshape (W \ - _vec (f₁ - mass_matrix * k₁), axes (uprev)) + k₁
277
+ k₂ = _reshape (W \ _vec (f₁ - mass_matrix * k₁), axes (uprev))
276
278
end
279
+ k₂ = @. . k₂ * neginvdtγ + k₁
277
280
integrator. stats. nsolve += 1
278
281
u = uprev + dt * k₂
279
282
@@ -282,30 +285,28 @@ end
282
285
OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
283
286
284
287
if mass_matrix === I
285
- k₃ = _reshape (
286
- W \
287
- - _vec ((integrator. fsallast - c₃₂ * (k₂ - f₁) -
288
- 2 * (k₁ - integrator. fsalfirst) + dt * dT)),
289
- axes (uprev))
288
+ linsolve_tmp = @. . (integrator. fsallast - c₃₂ * (k₂ - f₁) -
289
+ 2 * (k₁ - integrator. fsalfirst) + dt * dT)
290
290
else
291
- linsolve_tmp = integrator . fsallast - mass_matrix * (c₃₂ * k₂ + 2 * k₁) +
292
- c₃₂ * f₁ + 2 * integrator. fsalfirst + dt * dT
293
- k₃ = _reshape (W \ - _vec (linsolve_tmp), axes (uprev) )
291
+ linsolve_tmp = mass_matrix * (@. . c₃₂ * k₂ + 2 * k₁)
292
+ linsolve_tmp = @. . ( integrator. fsallast - linsolve_tmp +
293
+ c₃₂ * f₁ + 2 * integrator . fsalfirst + dt * dT )
294
294
end
295
+ k₃ = _reshape (W \ _vec (linsolve_tmp), axes (uprev)) * neginvdtγ
295
296
integrator. stats. nsolve += 1
296
297
297
298
if u isa Number
298
299
utilde = dto6 * f. mass_matrix[1 , 1 ] * (k₁ - 2 * k₂ + k₃)
299
300
else
300
- utilde = dto6 * f . mass_matrix * (k₁ - 2 * k₂ + k₃)
301
+ utilde = f . mass_matrix * ( @. . dto6 * (k₁ - 2 * k₂ + k₃) )
301
302
end
302
303
atmp = calculate_residuals (utilde, uprev, u, integrator. opts. abstol,
303
304
integrator. opts. reltol, integrator. opts. internalnorm, t)
304
305
integrator. EEst = integrator. opts. internalnorm (atmp, t)
305
306
306
307
if mass_matrix != = I
307
- atmp = @. ifelse ( ! integrator. differential_vars, integrator . fsallast, false ) ./
308
- integrator. opts . abstol
308
+ invatol = inv ( integrator. opts . abstol)
309
+ atmp = @. ifelse (integrator . differential_vars, false , integrator. fsallast) * invatol
309
310
integrator. EEst += integrator. opts. internalnorm (atmp, t)
310
311
end
311
312
end
321
322
@unpack c₃₂, d, tf, uf = cache
322
323
323
324
# Precalculations
324
- γ = dt * d
325
+ dtγ = dt * d
326
+ neginvdtγ = - inv (dtγ)
325
327
dto2 = dt / 2
326
328
dto6 = dt / 6
327
329
@@ -335,52 +337,52 @@ end
335
337
# Time derivative
336
338
dT = calc_tderivative (integrator, cache)
337
339
338
- W = calc_W (integrator, cache, γ , repeat_step)
340
+ W = calc_W (integrator, cache, dtγ , repeat_step, true )
339
341
if ! issuccess_W (W)
340
342
integrator. EEst = 2
341
343
return nothing
342
344
end
343
345
344
- k₁ = _reshape (W \ - _vec ((integrator. fsalfirst + γ * dT)), axes (uprev))
346
+ k₁ = _reshape (W \ - _vec ((integrator. fsalfirst + dtγ * dT)), axes (uprev))/ dtγ
345
347
integrator. stats. nsolve += 1
346
- f₁ = f (uprev + dto2 * k₁, p, t + dto2)
348
+ tmp = @. . uprev + dto2 * k₁
349
+ f₁ = f (tmp, p, t + dto2)
347
350
OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
348
351
349
352
if mass_matrix === I
350
- k₂ = _reshape (W \ - _vec (f₁ - k₁), axes (uprev)) + k₁
353
+ k₂ = _reshape (W \ _vec (f₁ - k₁), axes (uprev))
351
354
else
352
355
linsolve_tmp = f₁ - mass_matrix * k₁
353
- k₂ = _reshape (W \ - _vec (linsolve_tmp), axes (uprev)) + k₁
356
+ k₂ = _reshape (W \ _vec (linsolve_tmp), axes (uprev))
354
357
end
358
+ k₂ = @. . k₂ * neginvdtγ + k₁
355
359
356
360
integrator. stats. nsolve += 1
357
- tmp = uprev + dt * k₂
361
+ tmp = @. . uprev + dt * k₂
358
362
integrator. fsallast = f (tmp, p, t + dt)
359
363
OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
360
364
361
365
if mass_matrix === I
362
- k₃ = _reshape (
363
- W \
364
- - _vec ((integrator. fsallast - c₃₂ * (k₂ - f₁) -
365
- 2 (k₁ - integrator. fsalfirst) + dt * dT)),
366
- axes (uprev))
366
+ linsolve_tmp = @. . (integrator. fsallast - c₃₂ * (k₂ - f₁) -
367
+ 2 (k₁ - integrator. fsalfirst) + dt * dT)
367
368
else
368
- linsolve_tmp = integrator . fsallast - mass_matrix * (c₃₂ * k₂ + 2 k₁) + c₃₂ * f₁ +
369
- 2 * integrator. fsalfirst + dt * dT
370
- k₃ = _reshape (W \ - _vec (linsolve_tmp), axes (uprev) )
369
+ linsolve_tmp = mass_matrix * (@. . c₃₂ * k₂ + 2 * k₁)
370
+ linsolve_tmp = @. . ( integrator. fsallast - linsolve_tmp +
371
+ c₃₂ * f₁ + 2 * integrator . fsalfirst + dt * dT )
371
372
end
373
+ k₃ = _reshape (W \ _vec (linsolve_tmp), axes (uprev)) * neginvdtγ
372
374
integrator. stats. nsolve += 1
373
- u = uprev + dto6 * (k₁ + 4 k₂ + k₃)
375
+ u = @. . uprev + dto6 * (k₁ + 4 k₂ + k₃)
374
376
375
377
if integrator. opts. adaptive
376
- utilde = dto6 * (k₁ - 2 k₂ + k₃)
378
+ utilde = @. . dto6 * (k₁ - 2 k₂ + k₃)
377
379
atmp = calculate_residuals (utilde, uprev, u, integrator. opts. abstol,
378
380
integrator. opts. reltol, integrator. opts. internalnorm, t)
379
381
integrator. EEst = integrator. opts. internalnorm (atmp, t)
380
382
381
383
if mass_matrix != = I
382
- atmp = @. ifelse ( ! integrator. differential_vars, integrator . fsallast, false ) ./
383
- integrator. opts . abstol
384
+ invatol = inv ( integrator. opts . abstol)
385
+ atmp = ifelse (integrator . differential_vars, false , integrator. fsallast) .* invatol
384
386
integrator. EEst += integrator. opts. internalnorm (atmp, t)
385
387
end
386
388
end
0 commit comments