Skip to content

Commit e5dd0ff

Browse files
authored
Merge pull request #2307 from SciML/wtransform
Apply W-transform to Rosenbrock23
2 parents 52eca77 + b3d0d4a commit e5dd0ff

File tree

3 files changed

+96
-84
lines changed

3 files changed

+96
-84
lines changed

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
146146
assumptions = LinearSolve.OperatorAssumptions(true))
147147

148148
grad_config = build_grad_config(alg, f, tf, du1, t)
149-
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2, Val(false))
149+
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
150150
algebraic_vars = f.mass_matrix === I ? nothing :
151151
[all(iszero, x) for x in eachcol(f.mass_matrix)]
152152

@@ -191,7 +191,7 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
191191
Pl = Pl, Pr = Pr,
192192
assumptions = LinearSolve.OperatorAssumptions(true))
193193
grad_config = build_grad_config(alg, f, tf, du1, t)
194-
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2, Val(false))
194+
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
195195
algebraic_vars = f.mass_matrix === I ? nothing :
196196
[all(iszero, x) for x in eachcol(f.mass_matrix)]
197197

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl

Lines changed: 64 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ end
3333
mass_matrix = integrator.f.mass_matrix
3434

3535
# Precalculations
36-
γ = dt * d
36+
dtγ = dt * d
37+
neginvdtγ = -inv(dtγ)
3738
dto2 = dt / 2
3839
dto6 = dt / 6
3940

@@ -42,7 +43,7 @@ end
4243
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
4344
end
4445

45-
calc_rosenbrock_differentiation!(integrator, cache, γ, γ, repeat_step, false)
46+
calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step, true)
4647

4748
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
4849
integrator.opts.abstol, integrator.opts.reltol,
@@ -52,20 +53,20 @@ end
5253
linres = dolinsolve(
5354
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
5455
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
55-
solverdata = (; gamma = γ))
56+
solverdata = (; gamma = dtγ))
5657
else
5758
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
5859
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
59-
solverdata = (; gamma = γ))
60+
solverdata = (; gamma = dtγ))
6061
end
6162

6263
vecu = _vec(linres.u)
6364
veck₁ = _vec(k₁)
6465

65-
@.. broadcast=false veck₁=-vecu
66+
@.. veck₁ = vecu * neginvdtγ
6667
integrator.stats.nsolve += 1
6768

68-
@.. broadcast=false u=uprev + dto2 * k₁
69+
@.. u=uprev + dto2 * k₁
6970
stage_limiter!(u, integrator, p, t + dto2)
7071
f(f₁, u, p, t + dto2)
7172
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
@@ -76,17 +77,16 @@ end
7677
mul!(_vec(tmp), mass_matrix, _vec(k₁))
7778
end
7879

79-
@.. broadcast=false linsolve_tmp=f₁ - tmp
80+
@.. linsolve_tmp = f₁ - tmp
8081

8182
linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
8283
vecu = _vec(linres.u)
83-
veck2 = _vec(k₂)
84+
veck₂ = _vec(k₂)
8485

85-
@.. broadcast=false veck2=-vecu
86+
@.. veck₂ = vecu * neginvdtγ + veck₁
8687
integrator.stats.nsolve += 1
8788

88-
@.. broadcast=false k₂+=k₁
89-
@.. broadcast=false u=uprev + dt * k₂
89+
@.. u = uprev + dt * k₂
9090
stage_limiter!(u, integrator, p, t + dt)
9191
step_limiter!(u, integrator, p, t + dt)
9292

@@ -107,7 +107,7 @@ end
107107
linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
108108
vecu = _vec(linres.u)
109109
veck3 = _vec(k₃)
110-
@.. broadcast=false veck3=-vecu
110+
@.. veck3 = vecu * neginvdtγ
111111

112112
integrator.stats.nsolve += 1
113113

@@ -127,8 +127,8 @@ end
127127

128128
if mass_matrix !== I
129129
algvar = reshape(cache.algebraic_vars, size(u))
130-
@.. broadcast=false atmp=ifelse(algvar, fsallast, false) /
131-
integrator.opts.abstol
130+
invatol = inv(integrator.opts.abstol)
131+
@.. atmp = ifelse(algvar, fsallast, false) * invatol
132132
integrator.EEst += integrator.opts.internalnorm(atmp, t)
133133
end
134134
end
@@ -145,7 +145,8 @@ end
145145
mass_matrix = integrator.f.mass_matrix
146146

147147
# Precalculations
148-
γ = dt * d
148+
dtγ = dt * d
149+
neginvdtγ = -inv(dtγ)
149150
dto2 = dt / 2
150151
dto6 = dt / 6
151152

@@ -154,7 +155,7 @@ end
154155
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
155156
end
156157

157-
calc_rosenbrock_differentiation!(integrator, cache, γ, γ, repeat_step, false)
158+
calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step, true)
158159

159160
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
160161
integrator.opts.abstol, integrator.opts.reltol,
@@ -164,17 +165,17 @@ end
164165
linres = dolinsolve(
165166
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
166167
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
167-
solverdata = (; gamma = γ))
168+
solverdata = (; gamma = dtγ))
168169
else
169170
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
170171
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
171-
solverdata = (; gamma = γ))
172+
solverdata = (; gamma = dtγ))
172173
end
173174

174175
vecu = _vec(linres.u)
175176
veck₁ = _vec(k₁)
176177

177-
@.. broadcast=false veck₁=-vecu
178+
@.. veck₁ = vecu * neginvdtγ
178179
integrator.stats.nsolve += 1
179180

180181
@.. broadcast=false u=uprev + dto2 * k₁
@@ -192,13 +193,12 @@ end
192193

193194
linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
194195
vecu = _vec(linres.u)
195-
veck2 = _vec(k₂)
196+
veck₂ = _vec(k₂)
196197

197-
@.. broadcast=false veck2=-vecu
198+
@.. veck₂ = vecu * neginvdtγ + veck₁
198199
integrator.stats.nsolve += 1
199200

200-
@.. broadcast=false k₂+=k₁
201-
@.. broadcast=false tmp=uprev + dt * k₂
201+
@.. tmp = uprev + dt * k₂
202202
stage_limiter!(u, integrator, p, t + dt)
203203
f(fsallast, tmp, p, t + dt)
204204
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
@@ -216,7 +216,7 @@ end
216216
vecu = _vec(linres.u)
217217
veck3 = _vec(k₃)
218218

219-
@.. broadcast=false veck3=-vecu
219+
@.. veck3 = vecu * neginvdtγ
220220
integrator.stats.nsolve += 1
221221

222222
@.. broadcast=false u=uprev + dto6 * (k₁ + 4k₂ + k₃)
@@ -230,8 +230,8 @@ end
230230
integrator.EEst = integrator.opts.internalnorm(atmp, t)
231231

232232
if mass_matrix !== I
233-
@.. broadcast=false atmp=ifelse(cache.algebraic_vars, fsallast, false) /
234-
integrator.opts.abstol
233+
invatol = inv(integrator.opts.abstol)
234+
@.. atmp=ifelse(cache.algebraic_vars, fsallast, false) * invatol
235235
integrator.EEst += integrator.opts.internalnorm(atmp, t)
236236
end
237237
end
@@ -244,7 +244,8 @@ end
244244
@unpack c₃₂, d, tf, uf = cache
245245

246246
# Precalculations
247-
γ = dt * d
247+
dtγ = dt * d
248+
neginvdtγ = -inv(dtγ)
248249
dto2 = dt / 2
249250
dto6 = dt / 6
250251

@@ -258,22 +259,24 @@ end
258259
# Time derivative
259260
dT = calc_tderivative(integrator, cache)
260261

261-
W = calc_W(integrator, cache, γ, repeat_step)
262+
W = calc_W(integrator, cache, dtγ, repeat_step, true)
262263
if !issuccess_W(W)
263264
integrator.EEst = 2
264265
return nothing
265266
end
266267

267-
k₁ = _reshape(W \ -_vec((integrator.fsalfirst + γ * dT)), axes(uprev))
268+
k₁ = _reshape(W \ _vec((integrator.fsalfirst + dtγ * dT)), axes(uprev)) * neginvdtγ
268269
integrator.stats.nsolve += 1
269-
f₁ = f(uprev + dto2 * k₁, p, t + dto2)
270+
tmp = @.. uprev + dto2 * k₁
271+
f₁ = f(tmp, p, t + dto2)
270272
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
271273

272274
if mass_matrix === I
273-
k₂ = _reshape(W \ -_vec(f₁ - k₁), axes(uprev)) + k₁
275+
k₂ = _reshape(W \ _vec(f₁ - k₁), axes(uprev))
274276
else
275-
k₂ = _reshape(W \ -_vec(f₁ - mass_matrix * k₁), axes(uprev)) + k₁
277+
k₂ = _reshape(W \ _vec(f₁ - mass_matrix * k₁), axes(uprev))
276278
end
279+
k₂ = @.. k₂ * neginvdtγ + k₁
277280
integrator.stats.nsolve += 1
278281
u = uprev + dt * k₂
279282

@@ -282,30 +285,28 @@ end
282285
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
283286

284287
if mass_matrix === I
285-
k₃ = _reshape(
286-
W \
287-
-_vec((integrator.fsallast - c₃₂ * (k₂ - f₁) -
288-
2 * (k₁ - integrator.fsalfirst) + dt * dT)),
289-
axes(uprev))
288+
linsolve_tmp = @.. (integrator.fsallast - c₃₂ * (k₂ - f₁) -
289+
2 * (k₁ - integrator.fsalfirst) + dt * dT)
290290
else
291-
linsolve_tmp = integrator.fsallast - mass_matrix * (c₃₂ * k₂ + 2 * k₁) +
292-
c₃₂ * f₁ + 2 * integrator.fsalfirst + dt * dT
293-
k₃ = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
291+
linsolve_tmp = mass_matrix * (@.. c₃₂ * k₂ + 2 * k₁)
292+
linsolve_tmp = @.. (integrator.fsallast - linsolve_tmp +
293+
c₃₂ * f₁ + 2 * integrator.fsalfirst + dt * dT)
294294
end
295+
k₃ = _reshape(W \ _vec(linsolve_tmp), axes(uprev)) * neginvdtγ
295296
integrator.stats.nsolve += 1
296297

297298
if u isa Number
298299
utilde = dto6 * f.mass_matrix[1, 1] * (k₁ - 2 * k₂ + k₃)
299300
else
300-
utilde = dto6 * f.mass_matrix * (k₁ - 2 * k₂ + k₃)
301+
utilde = f.mass_matrix * (@.. dto6 * (k₁ - 2 * k₂ + k₃))
301302
end
302303
atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol,
303304
integrator.opts.reltol, integrator.opts.internalnorm, t)
304305
integrator.EEst = integrator.opts.internalnorm(atmp, t)
305306

306307
if mass_matrix !== I
307-
atmp = @. ifelse(!integrator.differential_vars, integrator.fsallast, false) ./
308-
integrator.opts.abstol
308+
invatol = inv(integrator.opts.abstol)
309+
atmp = @. ifelse(integrator.differential_vars, false, integrator.fsallast) * invatol
309310
integrator.EEst += integrator.opts.internalnorm(atmp, t)
310311
end
311312
end
@@ -321,7 +322,8 @@ end
321322
@unpack c₃₂, d, tf, uf = cache
322323

323324
# Precalculations
324-
γ = dt * d
325+
dtγ = dt * d
326+
neginvdtγ = -inv(dtγ)
325327
dto2 = dt / 2
326328
dto6 = dt / 6
327329

@@ -335,52 +337,52 @@ end
335337
# Time derivative
336338
dT = calc_tderivative(integrator, cache)
337339

338-
W = calc_W(integrator, cache, γ, repeat_step)
340+
W = calc_W(integrator, cache, dtγ, repeat_step, true)
339341
if !issuccess_W(W)
340342
integrator.EEst = 2
341343
return nothing
342344
end
343345

344-
k₁ = _reshape(W \ -_vec((integrator.fsalfirst + γ * dT)), axes(uprev))
346+
k₁ = _reshape(W \ -_vec((integrator.fsalfirst + dtγ * dT)), axes(uprev))/dtγ
345347
integrator.stats.nsolve += 1
346-
f₁ = f(uprev + dto2 * k₁, p, t + dto2)
348+
tmp = @.. uprev + dto2 * k₁
349+
f₁ = f(tmp, p, t + dto2)
347350
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
348351

349352
if mass_matrix === I
350-
k₂ = _reshape(W \ -_vec(f₁ - k₁), axes(uprev)) + k₁
353+
k₂ = _reshape(W \ _vec(f₁ - k₁), axes(uprev))
351354
else
352355
linsolve_tmp = f₁ - mass_matrix * k₁
353-
k₂ = _reshape(W \ -_vec(linsolve_tmp), axes(uprev)) + k₁
356+
k₂ = _reshape(W \ _vec(linsolve_tmp), axes(uprev))
354357
end
358+
k₂ = @.. k₂ * neginvdtγ + k₁
355359

356360
integrator.stats.nsolve += 1
357-
tmp = uprev + dt * k₂
361+
tmp = @.. uprev + dt * k₂
358362
integrator.fsallast = f(tmp, p, t + dt)
359363
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
360364

361365
if mass_matrix === I
362-
k₃ = _reshape(
363-
W \
364-
-_vec((integrator.fsallast - c₃₂ * (k₂ - f₁) -
365-
2(k₁ - integrator.fsalfirst) + dt * dT)),
366-
axes(uprev))
366+
linsolve_tmp = @.. (integrator.fsallast - c₃₂ * (k₂ - f₁) -
367+
2(k₁ - integrator.fsalfirst) + dt * dT)
367368
else
368-
linsolve_tmp = integrator.fsallast - mass_matrix * (c₃₂ * k₂ + 2k₁) + c₃₂ * f₁ +
369-
2 * integrator.fsalfirst + dt * dT
370-
k₃ = _reshape(W \ -_vec(linsolve_tmp), axes(uprev))
369+
linsolve_tmp = mass_matrix * (@.. c₃₂ * k₂ + 2 * k₁)
370+
linsolve_tmp = @.. (integrator.fsallast - linsolve_tmp +
371+
c₃₂ * f₁ + 2 * integrator.fsalfirst + dt * dT)
371372
end
373+
k₃ = _reshape(W \ _vec(linsolve_tmp), axes(uprev)) * neginvdtγ
372374
integrator.stats.nsolve += 1
373-
u = uprev + dto6 * (k₁ + 4k₂ + k₃)
375+
u = @.. uprev + dto6 * (k₁ + 4k₂ + k₃)
374376

375377
if integrator.opts.adaptive
376-
utilde = dto6 * (k₁ - 2k₂ + k₃)
378+
utilde = @.. dto6 * (k₁ - 2k₂ + k₃)
377379
atmp = calculate_residuals(utilde, uprev, u, integrator.opts.abstol,
378380
integrator.opts.reltol, integrator.opts.internalnorm, t)
379381
integrator.EEst = integrator.opts.internalnorm(atmp, t)
380382

381383
if mass_matrix !== I
382-
atmp = @. ifelse(!integrator.differential_vars, integrator.fsallast, false) ./
383-
integrator.opts.abstol
384+
invatol = inv(integrator.opts.abstol)
385+
atmp = ifelse(integrator.differential_vars, false, integrator.fsallast) .* invatol
384386
integrator.EEst += integrator.opts.internalnorm(atmp, t)
385387
end
386388
end

0 commit comments

Comments
 (0)