Skip to content

Commit dda2b9e

Browse files
committed
simplify dolinsolve
1 parent b609a4d commit dda2b9e

File tree

3 files changed

+11
-130
lines changed

3 files changed

+11
-130
lines changed

lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,19 @@ issuccess_W(W::Number) = !iszero(W)
33
issuccess_W(::Any) = true
44

55
function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
6-
du = nothing, u = nothing, p = nothing, t = nothing,
7-
weight = nothing, solverdata = nothing,
86
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
97
A !== nothing && (linsolve.A = A)
108
b !== nothing && (linsolve.b = b)
119
linu !== nothing && (linsolve.u = linu)
1210

13-
Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
14-
linsolve.Pl
15-
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
16-
linsolve.Pr
17-
1811
_alg = unwrap_alg(integrator, true)
1912

2013
_Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev,
2114
solverdata)
22-
if (_Pl !== nothing || _Pr !== nothing)
23-
__Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl
24-
__Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr
25-
linsolve.Pl = __Pl
26-
linsolve.Pr = __Pr
15+
if !isnothing(A)
16+
(;du, u, p, t) = integrator
17+
p = isnothing(integrator) ? nothing : (du, u, p, t)
18+
reinit!(linsolve; A, p)
2719
end
2820

2921
linres = solve!(linsolve; reltol)

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,7 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
100100
linsolve_tmp = zero(rate_prototype)
101101

102102
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
103-
Pl, Pr = wrapprecs(
104-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
105-
nothing)..., weight, tmp)
106103
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
107-
Pl = Pl, Pr = Pr,
108104
assumptions = LinearSolve.OperatorAssumptions(true))
109105

110106
grad_config = build_grad_config(alg, f, tf, du1, t)
@@ -146,11 +142,7 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
146142
linsolve_tmp = zero(rate_prototype)
147143
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
148144

149-
Pl, Pr = wrapprecs(
150-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
151-
nothing)..., weight, tmp)
152145
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
153-
Pl = Pl, Pr = Pr,
154146
assumptions = LinearSolve.OperatorAssumptions(true))
155147
grad_config = build_grad_config(alg, f, tf, du1, t)
156148
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2, Val(false))
@@ -294,11 +286,7 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
294286
uf = UJacobianWrapper(f, t, p)
295287
linsolve_tmp = zero(rate_prototype)
296288
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
297-
Pl, Pr = wrapprecs(
298-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
299-
nothing)..., weight, tmp)
300289
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
301-
Pl = Pl, Pr = Pr,
302290
assumptions = LinearSolve.OperatorAssumptions(true))
303291
grad_config = build_grad_config(alg, f, tf, du1, t)
304292
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -380,11 +368,7 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits},
380368
uf = UJacobianWrapper(f, t, p)
381369
linsolve_tmp = zero(rate_prototype)
382370
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
383-
Pl, Pr = wrapprecs(
384-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
385-
nothing)..., weight, tmp)
386371
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
387-
Pl = Pl, Pr = Pr,
388372
assumptions = LinearSolve.OperatorAssumptions(true))
389373
grad_config = build_grad_config(alg, f, tf, du1, t)
390374
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -573,11 +557,7 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits},
573557
uf = UJacobianWrapper(f, t, p)
574558
linsolve_tmp = zero(rate_prototype)
575559
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
576-
Pl, Pr = wrapprecs(
577-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
578-
nothing)..., weight, tmp)
579560
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
580-
Pl = Pl, Pr = Pr,
581561
assumptions = LinearSolve.OperatorAssumptions(true))
582562
grad_config = build_grad_config(alg, f, tf, du1, t)
583563
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -617,11 +597,7 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits},
617597
uf = UJacobianWrapper(f, t, p)
618598
linsolve_tmp = zero(rate_prototype)
619599
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
620-
Pl, Pr = wrapprecs(
621-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
622-
nothing)..., weight, tmp)
623600
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
624-
Pl = Pl, Pr = Pr,
625601
assumptions = LinearSolve.OperatorAssumptions(true))
626602
grad_config = build_grad_config(alg, f, tf, du1, t)
627603
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -740,11 +716,7 @@ function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits},
740716
uf = UJacobianWrapper(f, t, p)
741717
linsolve_tmp = zero(rate_prototype)
742718
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
743-
Pl, Pr = wrapprecs(
744-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
745-
nothing)..., weight, tmp)
746719
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
747-
Pl = Pl, Pr = Pr,
748720
assumptions = LinearSolve.OperatorAssumptions(true))
749721
grad_config = build_grad_config(alg, f, tf, du1, t)
750722
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -800,11 +772,7 @@ function alg_cache(alg::Rodas42, u, rate_prototype, ::Type{uEltypeNoUnits},
800772
uf = UJacobianWrapper(f, t, p)
801773
linsolve_tmp = zero(rate_prototype)
802774
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
803-
Pl, Pr = wrapprecs(
804-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
805-
nothing)..., weight, tmp)
806775
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
807-
Pl = Pl, Pr = Pr,
808776
assumptions = LinearSolve.OperatorAssumptions(true))
809777
grad_config = build_grad_config(alg, f, tf, du1, t)
810778
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -860,11 +828,7 @@ function alg_cache(alg::Rodas4P, u, rate_prototype, ::Type{uEltypeNoUnits},
860828
uf = UJacobianWrapper(f, t, p)
861829
linsolve_tmp = zero(rate_prototype)
862830
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
863-
Pl, Pr = wrapprecs(
864-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
865-
nothing)..., weight, tmp)
866831
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
867-
Pl = Pl, Pr = Pr,
868832
assumptions = LinearSolve.OperatorAssumptions(true))
869833
grad_config = build_grad_config(alg, f, tf, du1, t)
870834
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -920,11 +884,7 @@ function alg_cache(alg::Rodas4P2, u, rate_prototype, ::Type{uEltypeNoUnits},
920884
uf = UJacobianWrapper(f, t, p)
921885
linsolve_tmp = zero(rate_prototype)
922886
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
923-
Pl, Pr = wrapprecs(
924-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
925-
nothing)..., weight, tmp)
926887
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
927-
Pl = Pl, Pr = Pr,
928888
assumptions = LinearSolve.OperatorAssumptions(true))
929889
grad_config = build_grad_config(alg, f, tf, du1, t)
930890
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -1037,11 +997,7 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
1037997
uf = UJacobianWrapper(f, t, p)
1038998
linsolve_tmp = zero(rate_prototype)
1039999
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
1040-
Pl, Pr = wrapprecs(
1041-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
1042-
nothing)..., weight, tmp)
10431000
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
1044-
Pl = Pl, Pr = Pr,
10451001
assumptions = LinearSolve.OperatorAssumptions(true))
10461002
grad_config = build_grad_config(alg, f, tf, du1, t)
10471003
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -1101,11 +1057,7 @@ function alg_cache(
11011057
uf = UJacobianWrapper(f, t, p)
11021058
linsolve_tmp = zero(rate_prototype)
11031059
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
1104-
Pl, Pr = wrapprecs(
1105-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
1106-
nothing)..., weight, tmp)
11071060
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
1108-
Pl = Pl, Pr = Pr,
11091061
assumptions = LinearSolve.OperatorAssumptions(true))
11101062
grad_config = build_grad_config(alg, f, tf, du1, t)
11111063
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl

Lines changed: 7 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,7 @@ end
4848
integrator.opts.abstol, integrator.opts.reltol,
4949
integrator.opts.internalnorm, t)
5050

51-
if repeat_step
52-
linres = dolinsolve(
53-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
54-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
55-
solverdata = (; gamma = γ))
56-
else
57-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
58-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
59-
solverdata = (; gamma = γ))
60-
end
51+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
6152

6253
vecu = _vec(linres.u)
6354
veck₁ = _vec(k₁)
@@ -160,16 +151,7 @@ end
160151
integrator.opts.abstol, integrator.opts.reltol,
161152
integrator.opts.internalnorm, t)
162153

163-
if repeat_step
164-
linres = dolinsolve(
165-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
166-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
167-
solverdata = (; gamma = γ))
168-
else
169-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
170-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
171-
solverdata = (; gamma = γ))
172-
end
154+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
173155

174156
vecu = _vec(linres.u)
175157
veck₁ = _vec(k₁)
@@ -517,16 +499,7 @@ end
517499
integrator.opts.abstol, integrator.opts.reltol,
518500
integrator.opts.internalnorm, t)
519501

520-
if repeat_step
521-
linres = dolinsolve(
522-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
523-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
524-
solverdata = (; gamma = dtgamma))
525-
else
526-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
527-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
528-
solverdata = (; gamma = dtgamma))
529-
end
502+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
530503

531504
vecu = _vec(linres.u)
532505
veck1 = _vec(k1)
@@ -712,16 +685,7 @@ end
712685
integrator.opts.abstol, integrator.opts.reltol,
713686
integrator.opts.internalnorm, t)
714687

715-
if repeat_step
716-
linres = dolinsolve(
717-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
718-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
719-
solverdata = (; gamma = dtgamma))
720-
else
721-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
722-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
723-
solverdata = (; gamma = dtgamma))
724-
end
688+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
725689

726690
vecu = _vec(linres.u)
727691
veck1 = _vec(k1)
@@ -1020,16 +984,7 @@ end
1020984
integrator.opts.abstol, integrator.opts.reltol,
1021985
integrator.opts.internalnorm, t)
1022986

1023-
if repeat_step
1024-
linres = dolinsolve(
1025-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
1026-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1027-
solverdata = (; gamma = dtgamma))
1028-
else
1029-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
1030-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1031-
solverdata = (; gamma = dtgamma))
1032-
end
987+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
1033988

1034989
@.. broadcast=false $(_vec(k1))=-linres.u
1035990

@@ -1383,16 +1338,7 @@ end
13831338
integrator.opts.abstol, integrator.opts.reltol,
13841339
integrator.opts.internalnorm, t)
13851340

1386-
if repeat_step
1387-
linres = dolinsolve(
1388-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
1389-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1390-
solverdata = (; gamma = dtgamma))
1391-
else
1392-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
1393-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1394-
solverdata = (; gamma = dtgamma))
1395-
end
1341+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
13961342

13971343
@.. broadcast=false $(_vec(k1))=-linres.u
13981344

@@ -1786,16 +1732,7 @@ end
17861732
integrator.opts.abstol, integrator.opts.reltol,
17871733
integrator.opts.internalnorm, t)
17881734

1789-
if repeat_step
1790-
linres = dolinsolve(
1791-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
1792-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1793-
solverdata = (; gamma = dtgamma))
1794-
else
1795-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
1796-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1797-
solverdata = (; gamma = dtgamma))
1798-
end
1735+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
17991736

18001737
vecu = _vec(linres.u)
18011738
veck1 = _vec(k1)

0 commit comments

Comments
 (0)