Skip to content

Commit b0f957c

Browse files
Merge pull request #2297 from oscardssmith/os/always-w-transform
Make W_transform always true
2 parents 09aa469 + 783c88a commit b0f957c

File tree

8 files changed

+97
-265
lines changed

8 files changed

+97
-265
lines changed

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 55 additions & 217 deletions
Large diffs are not rendered by default.

lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,11 @@ function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
258258
nothing
259259
end
260260

261-
function build_jac_config(alg, f::F1, uf::F2, du1, uprev, u, tmp, du2,
262-
::Val{transform} = Val(true)) where {transform, F1, F2}
261+
function build_jac_config(alg, f::F1, uf::F2, du1, uprev, u, tmp, du2) where {F1, F2}
263262
haslinsolve = hasfield(typeof(alg), :linsolve)
264263

265264
if !DiffEqBase.has_jac(f) && # No Jacobian if has analytical solution
266-
(transform || !DiffEqBase.has_Wfact(f)) && # No Jacobian if has_Wfact and Wfact is the one that's used
267-
(!transform || !DiffEqBase.has_Wfact_t(f)) && # No Jacobian has_Wfact and Wfact_t is the one that's used
265+
(!DiffEqBase.has_Wfact_t(f)) &&
268266
((concrete_jac(alg) === nothing && (!haslinsolve || (haslinsolve && # No Jacobian if linsolve doesn't want it
269267
(alg.linsolve === nothing || LinearSolve.needs_concrete_A(alg.linsolve))))) ||
270268
(concrete_jac(alg) !== nothing && concrete_jac(alg))) # Jacobian if explicitly asked for

lib/OrdinaryDiffEqExtrapolation/src/extrapolation_perform_step.jl

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ function perform_step!(integrator, cache::ImplicitEulerExtrapolationCache,
290290
calc_J!(J, integrator, cache) # Store the calculated jac as it won't change in internal discretisation
291291
for index in 1:(n_curr + 1)
292292
dt_temp = dt / sequence[index]
293-
jacobian2W!(W[1], integrator.f.mass_matrix, dt_temp, J, true)
293+
jacobian2W!(W[1], integrator.f.mass_matrix, dt_temp, J)
294294
integrator.stats.nw += 1
295295
@.. broadcast=false k_tmps[1]=integrator.fsalfirst
296296
@.. broadcast=false u_tmps[1]=uprev
@@ -344,9 +344,7 @@ function perform_step!(integrator, cache::ImplicitEulerExtrapolationCache,
344344
endIndex = (i == 1) ? n_curr : n_curr + 1
345345
for index in startIndex:endIndex
346346
dt_temp = dt / sequence[index]
347-
jacobian2W!(
348-
W[Threads.threadid()], integrator.f.mass_matrix, dt_temp, J,
349-
true)
347+
jacobian2W!(W[Threads.threadid()], integrator.f.mass_matrix, dt_temp, J)
350348
@.. broadcast=false k_tmps[Threads.threadid()]=integrator.fsalfirst
351349
@.. broadcast=false u_tmps[Threads.threadid()]=uprev
352350
for j in 1:sequence[index]
@@ -445,7 +443,7 @@ function perform_step!(integrator, cache::ImplicitEulerExtrapolationCache,
445443
cache.n_curr = n_curr
446444

447445
dt_temp = dt / sequence[n_curr + 1]
448-
jacobian2W!(W[1], integrator.f.mass_matrix, dt_temp, J, false)
446+
jacobian2W!(W[1], integrator.f.mass_matrix, dt_temp, J)
449447
integrator.stats.nw += 1
450448
@.. broadcast=false k_tmps[1]=integrator.fsalfirst
451449
@.. broadcast=false u_tmps[1]=uprev
@@ -1170,7 +1168,7 @@ function perform_step!(integrator, cache::ImplicitDeuflhardExtrapolationCache,
11701168
for i in 0:n_curr
11711169
j_int = 4 * subdividing_sequence[i + 1]
11721170
dt_int = dt / j_int # Stepsize of the ith internal discretisation
1173-
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J, true)
1171+
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J)
11741172
integrator.stats.nw += 1
11751173
@.. broadcast=false u_temp2=uprev
11761174
@.. broadcast=false linsolve_tmps[1]=fsalfirst
@@ -1241,7 +1239,7 @@ function perform_step!(integrator, cache::ImplicitDeuflhardExtrapolationCache,
12411239
j_int_temp = 4 * subdividing_sequence[index + 1]
12421240
dt_int_temp = dt / j_int_temp # Stepsize of the ith internal discretisation
12431241
jacobian2W!(W[Threads.threadid()], integrator.f.mass_matrix,
1244-
dt_int_temp, J, true)
1242+
dt_int_temp, J)
12451243
@.. broadcast=false u_temp4[Threads.threadid()]=uprev
12461244
@.. broadcast=false linsolve_tmps[Threads.threadid()]=fsalfirst
12471245

@@ -1326,7 +1324,7 @@ function perform_step!(integrator, cache::ImplicitDeuflhardExtrapolationCache,
13261324
j_int_temp = 4 * subdividing_sequence[index + 1]
13271325
dt_int_temp = dt / j_int_temp # Stepsize of the ith internal discretisation
13281326
jacobian2W!(W[Threads.threadid()], integrator.f.mass_matrix,
1329-
dt_int_temp, J, true)
1327+
dt_int_temp, J)
13301328
@.. broadcast=false u_temp4[Threads.threadid()]=uprev
13311329
@.. broadcast=false linsolve_tmps[Threads.threadid()]=fsalfirst
13321330

@@ -1450,7 +1448,7 @@ function perform_step!(integrator, cache::ImplicitDeuflhardExtrapolationCache,
14501448
# Update cache.T
14511449
j_int = 4 * subdividing_sequence[n_curr + 1]
14521450
dt_int = dt / j_int # Stepsize of the new internal discretisation
1453-
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J, true)
1451+
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J)
14541452
integrator.stats.nw += 1
14551453
@.. broadcast=false u_temp2=uprev
14561454
@.. broadcast=false linsolve_tmps[1]=fsalfirst
@@ -2536,7 +2534,7 @@ function perform_step!(integrator, cache::ImplicitHairerWannerExtrapolationCache
25362534
for i in 0:n_curr
25372535
j_int = 4 * subdividing_sequence[i + 1]
25382536
dt_int = dt / j_int # Stepsize of the ith internal discretisation
2539-
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J, true)
2537+
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J)
25402538
integrator.stats.nw += 1
25412539
@.. broadcast=false u_temp2=uprev
25422540
@.. broadcast=false linsolve_tmps[1]=fsalfirst
@@ -2610,7 +2608,7 @@ function perform_step!(integrator, cache::ImplicitHairerWannerExtrapolationCache
26102608
j_int_temp = 4 * subdividing_sequence[index + 1]
26112609
dt_int_temp = dt / j_int_temp # Stepsize of the ith internal discretisation
26122610
jacobian2W!(W[Threads.threadid()], integrator.f.mass_matrix,
2613-
dt_int_temp, J, true)
2611+
dt_int_temp, J)
26142612
@.. broadcast=false u_temp4[Threads.threadid()]=uprev
26152613
@.. broadcast=false linsolve_tmps[Threads.threadid()]=fsalfirst
26162614

@@ -2701,7 +2699,7 @@ function perform_step!(integrator, cache::ImplicitHairerWannerExtrapolationCache
27012699
index == -1 && continue
27022700
j_int_temp = 4 * subdividing_sequence[index + 1]
27032701
dt_int_temp = dt / j_int_temp # Stepsize of the ith internal discretisation
2704-
jacobian2W!(W[tid], integrator.f.mass_matrix, dt_int_temp, J, true)
2702+
jacobian2W!(W[tid], integrator.f.mass_matrix, dt_int_temp, J)
27052703
@.. broadcast=false u_temp4[tid]=uprev
27062704
@.. broadcast=false linsolvetmp=fsalfirst
27072705

@@ -2815,7 +2813,7 @@ function perform_step!(integrator, cache::ImplicitHairerWannerExtrapolationCache
28152813
# Update cache.T
28162814
j_int = 4 * subdividing_sequence[n_curr + 1]
28172815
dt_int = dt / j_int # Stepsize of the new internal discretisation
2818-
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J, true)
2816+
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J)
28192817
integrator.stats.nw += 1
28202818
@.. broadcast=false u_temp2=uprev
28212819
@.. broadcast=false linsolve_tmps[1]=fsalfirst
@@ -3227,7 +3225,7 @@ function perform_step!(integrator, cache::ImplicitEulerBarycentricExtrapolationC
32273225
for i in 0:n_curr
32283226
j_int = sequence_factor * subdividing_sequence[i + 1]
32293227
dt_int = dt / j_int # Stepsize of the ith internal discretisation
3230-
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J, true)
3228+
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J)
32313229
integrator.stats.nw += 1
32323230
@.. broadcast=false u_temp2=uprev
32333231
@.. broadcast=false linsolve_tmps[1]=fsalfirst
@@ -3301,7 +3299,7 @@ function perform_step!(integrator, cache::ImplicitEulerBarycentricExtrapolationC
33013299
j_int_temp = sequence_factor * subdividing_sequence[index + 1]
33023300
dt_int_temp = dt / j_int_temp # Stepsize of the ith internal discretisation
33033301
jacobian2W!(W[Threads.threadid()], integrator.f.mass_matrix,
3304-
dt_int_temp, J, true)
3302+
dt_int_temp, J)
33053303
@.. broadcast=false u_temp4[Threads.threadid()]=uprev
33063304
@.. broadcast=false linsolve_tmps[Threads.threadid()]=fsalfirst
33073305

@@ -3389,7 +3387,7 @@ function perform_step!(integrator, cache::ImplicitEulerBarycentricExtrapolationC
33893387
j_int_temp = sequence_factor * subdividing_sequence[index + 1]
33903388
dt_int_temp = dt / j_int_temp # Stepsize of the ith internal discretisation
33913389
jacobian2W!(W[Threads.threadid()], integrator.f.mass_matrix,
3392-
dt_int_temp, J, true)
3390+
dt_int_temp, J)
33933391
@.. broadcast=false u_temp4[Threads.threadid()]=uprev
33943392
@.. broadcast=false linsolve_tmps[Threads.threadid()]=fsalfirst
33953393

@@ -3519,7 +3517,7 @@ function perform_step!(integrator, cache::ImplicitEulerBarycentricExtrapolationC
35193517
# Update cache.T
35203518
j_int = sequence_factor * subdividing_sequence[n_curr + 1]
35213519
dt_int = dt / j_int # Stepsize of the new internal discretisation
3522-
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J, true)
3520+
jacobian2W!(W[1], integrator.f.mass_matrix, dt_int, J)
35233521
integrator.stats.nw += 1
35243522
@.. broadcast=false u_temp2=uprev
35253523
@.. broadcast=false linsolve_tmps[1]=fsalfirst

lib/OrdinaryDiffEqRosenbrock/src/generic_rosenbrock.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -358,9 +358,9 @@ function gen_constant_perform_step(tabmask::RosenbrockTableau{Bool,Bool},cachena
358358

359359
# Time derivative
360360
tf.u = uprev
361-
dT = ForwardDiff.derivative(tf, t)
361+
dT = calc_tderivative(integrator, cache)
362362

363-
W = calc_W(integrator, cache, dtgamma, repeat_step, true)
363+
W = calc_W(integrator, cache, dtgamma, repeat_step)
364364
linsolve_tmp = integrator.fsalfirst + dtd1*dT #calc_rosenbrock_differentiation!
365365

366366
$(iterexprs...)
@@ -476,7 +476,7 @@ function gen_perform_step(tabmask::RosenbrockTableau{Bool,Bool},cachename::Symbo
476476
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
477477
integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t)
478478

479-
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true)
479+
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
480480

481481
linsolve = cache.linsolve
482482

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ end
4343
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
4444
end
4545

46-
calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step, true)
46+
calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step)
4747

4848
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
4949
integrator.opts.abstol, integrator.opts.reltol,
@@ -155,7 +155,7 @@ end
155155
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
156156
end
157157

158-
calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step, true)
158+
calc_rosenbrock_differentiation!(integrator, cache, dtγ, dtγ, repeat_step)
159159

160160
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
161161
integrator.opts.abstol, integrator.opts.reltol,
@@ -259,7 +259,7 @@ end
259259
# Time derivative
260260
dT = calc_tderivative(integrator, cache)
261261

262-
W = calc_W(integrator, cache, dtγ, repeat_step, true)
262+
W = calc_W(integrator, cache, dtγ, repeat_step)
263263
if !issuccess_W(W)
264264
integrator.EEst = 2
265265
return nothing
@@ -338,7 +338,7 @@ end
338338
# Time derivative
339339
dT = calc_tderivative(integrator, cache)
340340

341-
W = calc_W(integrator, cache, dtγ, repeat_step, true)
341+
W = calc_W(integrator, cache, dtγ, repeat_step)
342342
if !issuccess_W(W)
343343
integrator.EEst = 2
344344
return nothing
@@ -444,7 +444,7 @@ end
444444
# Time derivative
445445
dT = calc_tderivative(integrator, cache)
446446

447-
W = calc_W(integrator, cache, dtgamma, repeat_step, true)
447+
W = calc_W(integrator, cache, dtgamma, repeat_step)
448448
if !issuccess_W(W)
449449
integrator.EEst = 2
450450
return nothing
@@ -515,7 +515,7 @@ end
515515
dtd3 = dt * d3
516516
dtgamma = dt * gamma
517517

518-
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true)
518+
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
519519

520520
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
521521
integrator.opts.abstol, integrator.opts.reltol,
@@ -623,7 +623,7 @@ end
623623
tf.u = uprev
624624
dT = calc_tderivative(integrator, cache)
625625

626-
W = calc_W(integrator, cache, dtgamma, repeat_step, true)
626+
W = calc_W(integrator, cache, dtgamma, repeat_step)
627627
if !issuccess_W(W)
628628
integrator.EEst = 2
629629
return nothing
@@ -710,7 +710,7 @@ end
710710
dtd4 = dt * d4
711711
dtgamma = dt * gamma
712712

713-
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true)
713+
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
714714

715715
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
716716
integrator.opts.abstol, integrator.opts.reltol,
@@ -876,7 +876,7 @@ end
876876
tf.u = uprev
877877
dT = calc_tderivative(integrator, cache)
878878

879-
W = calc_W(integrator, cache, dtgamma, repeat_step, true)
879+
W = calc_W(integrator, cache, dtgamma, repeat_step)
880880
if !issuccess_W(W)
881881
integrator.EEst = 2
882882
return nothing
@@ -1018,7 +1018,7 @@ end
10181018
f(cache.fsalfirst, uprev, p, t) # used in calc_rosenbrock_differentiation!
10191019
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
10201020

1021-
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true)
1021+
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
10221022

10231023
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
10241024
integrator.opts.abstol, integrator.opts.reltol,
@@ -1226,7 +1226,7 @@ end
12261226
tf.u = uprev
12271227
dT = calc_tderivative(integrator, cache)
12281228

1229-
W = calc_W(integrator, cache, dtgamma, repeat_step, true)
1229+
W = calc_W(integrator, cache, dtgamma, repeat_step)
12301230
if !issuccess_W(W)
12311231
integrator.EEst = 2
12321232
return nothing
@@ -1317,7 +1317,7 @@ end
13171317
f(cache.fsalfirst, uprev, p, t)
13181318
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
13191319

1320-
calc_rosenbrock_differentiation!(integrator, cache, dtd[1], dtgamma, repeat_step, true)
1320+
calc_rosenbrock_differentiation!(integrator, cache, dtd[1], dtgamma, repeat_step)
13211321

13221322
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
13231323
integrator.opts.abstol, integrator.opts.reltol,
@@ -1449,7 +1449,7 @@ end
14491449
# Time derivative
14501450
dT = calc_tderivative(integrator, cache)
14511451

1452-
W = calc_W(integrator, cache, dtgamma, repeat_step, true)
1452+
W = calc_W(integrator, cache, dtgamma, repeat_step)
14531453
if !issuccess_W(W)
14541454
integrator.EEst = 2
14551455
return nothing
@@ -1662,7 +1662,7 @@ end
16621662
f(cache.fsalfirst, uprev, p, t) # used in calc_rosenbrock_differentiation!
16631663
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
16641664

1665-
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step, true)
1665+
calc_rosenbrock_differentiation!(integrator, cache, dtd1, dtgamma, repeat_step)
16661666

16671667
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
16681668
integrator.opts.abstol, integrator.opts.reltol,

lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,
6363

6464
### Jacobian does not need to be re-evaluated after an event
6565
### Since it's unchanged
66-
jacobian2W!(W, mass_matrix, dtγ, J, true)
66+
jacobian2W!(W, mass_matrix, dtγ, J)
6767

6868
linsolve = cache.linsolve
6969

@@ -215,7 +215,7 @@ function _ode_addsteps!(
215215

216216
### Jacobian does not need to be re-evaluated after an event
217217
### Since it's unchanged
218-
jacobian2W!(W, mass_matrix, dtgamma, J, true)
218+
jacobian2W!(W, mass_matrix, dtgamma, J)
219219

220220
linsolve = cache.linsolve
221221

@@ -394,7 +394,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::RosenbrockCache,
394394
@.. linsolve_tmp = @muladd fsalfirst + dtgamma * dT
395395

396396
# Jacobian does not need to be re-evaluated after an event since it's unchanged
397-
jacobian2W!(W, mass_matrix, dtgamma, J, true)
397+
jacobian2W!(W, mass_matrix, dtgamma, J)
398398

399399
linsolve = cache.linsolve
400400

@@ -623,7 +623,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p, cache::Rosenbrock5Cache,
623623

624624
### Jacobian does not need to be re-evaluated after an event
625625
### Since it's unchanged
626-
jacobian2W!(W, mass_matrix, dtgamma, J, true)
626+
jacobian2W!(W, mass_matrix, dtgamma, J)
627627

628628
linsolve = cache.linsolve
629629

test/interface/utility_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using OrdinaryDiffEq.OrdinaryDiffEqDifferentiation: WOperator, calc_W, calc_W!,
1010
tspan = (0.0, 1.0)
1111
dt = 0.01
1212
dtgamma = 0.5dt
13-
concrete_W = -mm + dtgamma * A
13+
concrete_W = A - inv(dtgamma)*mm
1414

1515
# Out-of-place
1616
fun = ODEFunction((u, p, t) -> A * u;
@@ -39,7 +39,7 @@ using OrdinaryDiffEq.OrdinaryDiffEqDifferentiation: WOperator, calc_W, calc_W!,
3939

4040
# But jacobian2W! will update the cache
4141
jacobian2W!(integrator.cache.nlsolver.cache.W._concrete_form, mm,
42-
dtgamma, integrator.cache.nlsolver.cache.W.J.A, false)
42+
dtgamma, integrator.cache.nlsolver.cache.W.J.A)
4343
@test convert(AbstractMatrix, integrator.cache.nlsolver.cache.W) == concrete_W
4444
ldiv!(tmp, lu!(integrator.cache.nlsolver.cache.W), u0)
4545
@test tmp == concrete_W \ u0

test/interface/wprototype_tests.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@ for prob in (prob_ode_vanderpol_stiff,)
1717
update_func = (old_val, u, p, t; dtgamma) -> dtgamma,
1818
accepted_kwargs = (:dtgamma,))
1919
transform_op = ScalarOperator(0.0;
20-
update_func = (old_op, u, p, t; dtgamma, transform) -> transform ?
21-
inv(dtgamma) :
22-
one(dtgamma),
23-
accepted_kwargs = (:dtgamma, :transform))
20+
update_func = (old_op, u, p, t; dtgamma) -> inv(dtgamma),
21+
accepted_kwargs = (:dtgamma,))
2422
W_op = -(I - gamma_op * J_op) * transform_op
2523

2624
# Make problem with custom MatrixOperator jac_prototype

0 commit comments

Comments
 (0)