Skip to content

Commit d4ea754

Browse files
committed
simplify dolinsolve
1 parent 07470f9 commit d4ea754

File tree

3 files changed

+13
-132
lines changed

3 files changed

+13
-132
lines changed

lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,19 @@ issuccess_W(W::Number) = !iszero(W)
33
issuccess_W(::Any) = true
44

55
function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
6-
du = nothing, u = nothing, p = nothing, t = nothing,
7-
weight = nothing, solverdata = nothing,
86
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
97
A !== nothing && (linsolve.A = A)
108
b !== nothing && (linsolve.b = b)
119
linu !== nothing && (linsolve.u = linu)
1210

13-
Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
14-
linsolve.Pl
15-
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
16-
linsolve.Pr
17-
1811
_alg = unwrap_alg(integrator, true)
1912

2013
_Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev,
2114
solverdata)
22-
if (_Pl !== nothing || _Pr !== nothing)
23-
__Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl
24-
__Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr
25-
linsolve.Pl = __Pl
26-
linsolve.Pr = __Pr
15+
if !isnothing(A)
16+
(;du, u, p, t) = integrator
17+
p = isnothing(integrator) ? nothing : (du, u, p, t)
18+
reinit!(linsolve; A, p)
2719
end
2820

2921
linres = solve!(linsolve; reltol)

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,7 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
9797
linsolve_tmp = zero(rate_prototype)
9898

9999
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
100-
Pl, Pr = wrapprecs(
101-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
102-
nothing)..., weight, tmp)
103100
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
104-
Pl = Pl, Pr = Pr,
105101
assumptions = LinearSolve.OperatorAssumptions(true))
106102

107103
grad_config = build_grad_config(alg, f, tf, du1, t)
@@ -143,11 +139,7 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
143139
linsolve_tmp = zero(rate_prototype)
144140
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
145141

146-
Pl, Pr = wrapprecs(
147-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
148-
nothing)..., weight, tmp)
149142
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
150-
Pl = Pl, Pr = Pr,
151143
assumptions = LinearSolve.OperatorAssumptions(true))
152144
grad_config = build_grad_config(alg, f, tf, du1, t)
153145
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2, Val(false))
@@ -291,11 +283,7 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
291283
uf = UJacobianWrapper(f, t, p)
292284
linsolve_tmp = zero(rate_prototype)
293285
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
294-
Pl, Pr = wrapprecs(
295-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
296-
nothing)..., weight, tmp)
297286
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
298-
Pl = Pl, Pr = Pr,
299287
assumptions = LinearSolve.OperatorAssumptions(true))
300288
grad_config = build_grad_config(alg, f, tf, du1, t)
301289
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -377,11 +365,7 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits},
377365
uf = UJacobianWrapper(f, t, p)
378366
linsolve_tmp = zero(rate_prototype)
379367
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
380-
Pl, Pr = wrapprecs(
381-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
382-
nothing)..., weight, tmp)
383368
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
384-
Pl = Pl, Pr = Pr,
385369
assumptions = LinearSolve.OperatorAssumptions(true))
386370
grad_config = build_grad_config(alg, f, tf, du1, t)
387371
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -570,11 +554,7 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits},
570554
uf = UJacobianWrapper(f, t, p)
571555
linsolve_tmp = zero(rate_prototype)
572556
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
573-
Pl, Pr = wrapprecs(
574-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
575-
nothing)..., weight, tmp)
576557
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
577-
Pl = Pl, Pr = Pr,
578558
assumptions = LinearSolve.OperatorAssumptions(true))
579559
grad_config = build_grad_config(alg, f, tf, du1, t)
580560
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -615,11 +595,7 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits},
615595
uf = UJacobianWrapper(f, t, p)
616596
linsolve_tmp = zero(rate_prototype)
617597
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
618-
Pl, Pr = wrapprecs(
619-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
620-
nothing)..., weight, tmp)
621598
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
622-
Pl = Pl, Pr = Pr,
623599
assumptions = LinearSolve.OperatorAssumptions(true))
624600
grad_config = build_grad_config(alg, f, tf, du1, t)
625601
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -740,11 +716,7 @@ function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits},
740716
uf = UJacobianWrapper(f, t, p)
741717
linsolve_tmp = zero(rate_prototype)
742718
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
743-
Pl, Pr = wrapprecs(
744-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
745-
nothing)..., weight, tmp)
746719
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
747-
Pl = Pl, Pr = Pr,
748720
assumptions = LinearSolve.OperatorAssumptions(true))
749721
grad_config = build_grad_config(alg, f, tf, du1, t)
750722
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -802,11 +774,7 @@ function alg_cache(alg::Rodas42, u, rate_prototype, ::Type{uEltypeNoUnits},
802774
uf = UJacobianWrapper(f, t, p)
803775
linsolve_tmp = zero(rate_prototype)
804776
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
805-
Pl, Pr = wrapprecs(
806-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
807-
nothing)..., weight, tmp)
808777
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
809-
Pl = Pl, Pr = Pr,
810778
assumptions = LinearSolve.OperatorAssumptions(true))
811779
grad_config = build_grad_config(alg, f, tf, du1, t)
812780
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -862,11 +830,7 @@ function alg_cache(alg::Rodas4P, u, rate_prototype, ::Type{uEltypeNoUnits},
862830
uf = UJacobianWrapper(f, t, p)
863831
linsolve_tmp = zero(rate_prototype)
864832
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
865-
Pl, Pr = wrapprecs(
866-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
867-
nothing)..., weight, tmp)
868833
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
869-
Pl = Pl, Pr = Pr,
870834
assumptions = LinearSolve.OperatorAssumptions(true))
871835
grad_config = build_grad_config(alg, f, tf, du1, t)
872836
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -922,11 +886,7 @@ function alg_cache(alg::Rodas4P2, u, rate_prototype, ::Type{uEltypeNoUnits},
922886
uf = UJacobianWrapper(f, t, p)
923887
linsolve_tmp = zero(rate_prototype)
924888
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
925-
Pl, Pr = wrapprecs(
926-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
927-
nothing)..., weight, tmp)
928889
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
929-
Pl = Pl, Pr = Pr,
930890
assumptions = LinearSolve.OperatorAssumptions(true))
931891
grad_config = build_grad_config(alg, f, tf, du1, t)
932892
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -1041,11 +1001,7 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
10411001
uf = UJacobianWrapper(f, t, p)
10421002
linsolve_tmp = zero(rate_prototype)
10431003
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
1044-
Pl, Pr = wrapprecs(
1045-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
1046-
nothing)..., weight, tmp)
10471004
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
1048-
Pl = Pl, Pr = Pr,
10491005
assumptions = LinearSolve.OperatorAssumptions(true))
10501006
grad_config = build_grad_config(alg, f, tf, du1, t)
10511007
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -1105,11 +1061,7 @@ function alg_cache(
11051061
uf = UJacobianWrapper(f, t, p)
11061062
linsolve_tmp = zero(rate_prototype)
11071063
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
1108-
Pl, Pr = wrapprecs(
1109-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
1110-
nothing)..., weight, tmp)
11111064
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
1112-
Pl = Pl, Pr = Pr,
11131065
assumptions = LinearSolve.OperatorAssumptions(true))
11141066
grad_config = build_grad_config(alg, f, tf, du1, t)
11151067
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -1140,4 +1092,4 @@ end
11401092

11411093
### RosenbrockW6S4O
11421094

1143-
@RosenbrockW6S4OS(:cache)
1095+
@RosenbrockW6S4OS(:cache)

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl

Lines changed: 8 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,7 @@ end
5050
integrator.opts.abstol, integrator.opts.reltol,
5151
integrator.opts.internalnorm, t)
5252

53-
if repeat_step
54-
linres = dolinsolve(
55-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
56-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
57-
solverdata = (; gamma = γ))
58-
else
59-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
60-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
61-
solverdata = (; gamma = γ))
62-
end
53+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
6354

6455
vecu = _vec(linres.u)
6556
veck₁ = _vec(k₁)
@@ -162,16 +153,7 @@ end
162153
integrator.opts.abstol, integrator.opts.reltol,
163154
integrator.opts.internalnorm, t)
164155

165-
if repeat_step
166-
linres = dolinsolve(
167-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
168-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
169-
solverdata = (; gamma = γ))
170-
else
171-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
172-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
173-
solverdata = (; gamma = γ))
174-
end
156+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
175157

176158
vecu = _vec(linres.u)
177159
veck₁ = _vec(k₁)
@@ -521,16 +503,7 @@ end
521503
integrator.opts.abstol, integrator.opts.reltol,
522504
integrator.opts.internalnorm, t)
523505

524-
if repeat_step
525-
linres = dolinsolve(
526-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
527-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
528-
solverdata = (; gamma = dtgamma))
529-
else
530-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
531-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
532-
solverdata = (; gamma = dtgamma))
533-
end
506+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
534507

535508
vecu = _vec(linres.u)
536509
veck1 = _vec(k1)
@@ -716,16 +689,7 @@ end
716689
integrator.opts.abstol, integrator.opts.reltol,
717690
integrator.opts.internalnorm, t)
718691

719-
if repeat_step
720-
linres = dolinsolve(
721-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
722-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
723-
solverdata = (; gamma = dtgamma))
724-
else
725-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
726-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
727-
solverdata = (; gamma = dtgamma))
728-
end
692+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
729693

730694
vecu = _vec(linres.u)
731695
veck1 = _vec(k1)
@@ -1024,16 +988,7 @@ end
1024988
integrator.opts.abstol, integrator.opts.reltol,
1025989
integrator.opts.internalnorm, t)
1026990

1027-
if repeat_step
1028-
linres = dolinsolve(
1029-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
1030-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1031-
solverdata = (; gamma = dtgamma))
1032-
else
1033-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
1034-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1035-
solverdata = (; gamma = dtgamma))
1036-
end
991+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
1037992

1038993
@.. broadcast=false $(_vec(k1))=-linres.u
1039994

@@ -1387,16 +1342,7 @@ end
13871342
integrator.opts.abstol, integrator.opts.reltol,
13881343
integrator.opts.internalnorm, t)
13891344

1390-
if repeat_step
1391-
linres = dolinsolve(
1392-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
1393-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1394-
solverdata = (; gamma = dtgamma))
1395-
else
1396-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
1397-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1398-
solverdata = (; gamma = dtgamma))
1399-
end
1345+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
14001346

14011347
@.. broadcast=false $(_vec(k1))=-linres.u
14021348

@@ -1790,16 +1736,7 @@ end
17901736
integrator.opts.abstol, integrator.opts.reltol,
17911737
integrator.opts.internalnorm, t)
17921738

1793-
if repeat_step
1794-
linres = dolinsolve(
1795-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
1796-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1797-
solverdata = (; gamma = dtgamma))
1798-
else
1799-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
1800-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1801-
solverdata = (; gamma = dtgamma))
1802-
end
1739+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
18031740

18041741
vecu = _vec(linres.u)
18051742
veck1 = _vec(k1)
@@ -1986,4 +1923,4 @@ end
19861923
end
19871924

19881925
@RosenbrockW6S4OS(:init)
1989-
@RosenbrockW6S4OS(:performstep)
1926+
@RosenbrockW6S4OS(:performstep)

0 commit comments

Comments
 (0)