Skip to content

Commit 671775f

Browse files
committed
simplify dolinsolve
1 parent 94a6fbc commit 671775f

39 files changed

+392
-878
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ OrdinaryDiffEqQPRK = "1"
121121
OrdinaryDiffEqRKN = "1"
122122
OrdinaryDiffEqRosenbrock = "1"
123123
OrdinaryDiffEqSDIRK = "1"
124-
OrdinaryDiffEqStabilizedIRK = "1"
125124
OrdinaryDiffEqSSPRK = "1"
125+
OrdinaryDiffEqStabilizedIRK = "1"
126126
OrdinaryDiffEqStabilizedRK = "1"
127127
OrdinaryDiffEqSymplecticRK = "1"
128128
OrdinaryDiffEqTsit5 = "1"

lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
77
OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache,
88
OrdinaryDiffEqNewtonAdaptiveAlgorithm,
99
OrdinaryDiffEqNewtonAlgorithm,
10-
AbstractController, DEFAULT_PRECS,
11-
CompiledFloats, uses_uprev,
10+
AbstractController, CompiledFloats, uses_uprev,
1211
alg_cache, _vec, _reshape, @cache,
1312
isfsal, full_cache,
1413
constvalue, isadaptive, error_constant,

lib/OrdinaryDiffEqBDF/src/algorithms.jl

Lines changed: 42 additions & 93 deletions
Large diffs are not rendered by default.

lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ end
105105
Divergence = -2
106106
end
107107
const TryAgain = SlowConvergence
108-
109-
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing
110108
isdiscretecache(cache) = false
111109

110+
# unused. Delete this once StocasticDiffEq doesn't use it
111+
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing
112112
include("doc_utils.jl")
113113
include("misc_utils.jl")
114114

lib/OrdinaryDiffEqCore/src/doc_utils.jl

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ function differentiation_rk_docstring(description::String,
8787
concrete_jac = nothing,
8888
diff_type = Val{:forward},
8989
linsolve = nothing,
90-
precs = DEFAULT_PRECS,
9190
""" * extra_keyword_default
9291

9392
keyword_default_description = """
@@ -111,40 +110,7 @@ function differentiation_rk_docstring(description::String,
111110
For example, to use [KLU.jl](https://github.com/JuliaSparse/KLU.jl), specify
112111
`$name(linsolve = KLUFactorization()`).
113112
When `nothing` is passed, uses `DefaultLinearSolver`.
114-
- `precs`: Any [LinearSolve.jl-compatible preconditioner](https://docs.sciml.ai/LinearSolve/stable/basics/Preconditioners/)
115-
can be used as a left or right preconditioner.
116-
Preconditioners are specified by the `Pl,Pr = precs(W,du,u,p,t,newW,Plprev,Prprev,solverdata)`
117-
function where the arguments are defined as:
118-
- `W`: the current Jacobian of the nonlinear system. Specified as either
119-
``I - \\gamma J`` or ``I/\\gamma - J`` depending on the algorithm. This will
120-
commonly be a `WOperator` type defined by OrdinaryDiffEq.jl. It is a lazy
121-
representation of the operator. Users can construct the W-matrix on demand
122-
by calling `convert(AbstractMatrix,W)` to receive an `AbstractMatrix` matching
123-
the `jac_prototype`.
124-
- `du`: the current ODE derivative
125-
- `u`: the current ODE state
126-
- `p`: the ODE parameters
127-
- `t`: the current ODE time
128-
- `newW`: a `Bool` which specifies whether the `W` matrix has been updated since
129-
the last call to `precs`. It is recommended that this is checked to only
130-
update the preconditioner when `newW == true`.
131-
- `Plprev`: the previous `Pl`.
132-
- `Prprev`: the previous `Pr`.
133-
- `solverdata`: Optional extra data the solvers can give to the `precs` function.
134-
Solver-dependent and subject to change.
135-
The return is a tuple `(Pl,Pr)` of the LinearSolve.jl-compatible preconditioners.
136-
To specify one-sided preconditioning, simply return `nothing` for the preconditioner
137-
which is not used. Additionally, `precs` must supply the dispatch:
138-
```julia
139-
Pl, Pr = precs(W, du, u, p, t, ::Nothing, ::Nothing, ::Nothing, solverdata)
140-
```
141-
which is used in the solver setup phase to construct the integrator
142-
type with the preconditioners `(Pl,Pr)`.
143-
The default is `precs=DEFAULT_PRECS` where the default preconditioner function
144-
is defined as:
145-
```julia
146-
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing
147-
```
113+
148114
""" * extra_keyword_description
149115

150116
generic_solver_docstring(

lib/OrdinaryDiffEqDefault/test/default_solver_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ rosensol = solve(prob_rober, AutoTsit5(Rosenbrock23(autodiff = false)))
4949
sol = solve(prob_rober, reltol = 1e-7, abstol = 1e-7)
5050
rosensol = solve(
5151
prob_rober, AutoVern7(Rodas5P(autodiff = false)), reltol = 1e-7, abstol = 1e-7)
52-
# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this).
52+
# test that default has the same performance as AutoTsit5(Rodas5P()) (which we expect it to use for this).
5353
@test sol.stats.naccept == rosensol.stats.naccept
5454
@test sol.stats.nf == rosensol.stats.nf
5555
@test unique(sol.alg_choice) == [2, 4]
@@ -75,7 +75,7 @@ for n in (100, 600)
7575
vcat([1.0, 0.0, 0.0], ones(n)), (0.0, 100.0), (0.04, 3e7, 1e4))
7676
global sol = solve(prob_ex_rober)
7777
fsol = solve(prob_ex_rober, AutoTsit5(FBDF(; autodiff = false, linsolve)))
78-
# test that default has the same performance as AutoTsit5(Rosenbrock23()) (which we expect it to use for this).
78+
# test that default has the same performance as AutoTsit5(FBDF()) (which we expect it to use for this).
7979
@test sol.stats.naccept == fsol.stats.naccept
8080
@test sol.stats.nf == fsol.stats.nf
8181
@test unique(sol.alg_choice) == [1, stiffalg]

lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, S
2727
using DiffEqBase: TimeGradientWrapper,
2828
UJacobianWrapper, TimeDerivativeWrapper,
2929
UDerivativeWrapper
30-
using SciMLBase: AbstractSciMLOperator
30+
using SciMLBase: AbstractSciMLOperator, DEIntegrator
3131
import OrdinaryDiffEqCore
3232
using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplicitAlgorithm,
3333
DAEAlgorithm,

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
883883
# Thus setup JacVec and a concrete J, using sparsity when possible
884884
_f = islin ? (isode ? f.f : f.f1.f) : f
885885
J = if f.jac_prototype === nothing
886-
ArrayInterface.undefmatrix(u)
886+
ArrayInterface.zeromatrix(u)
887887
else
888888
deepcopy(f.jac_prototype)
889889
end
@@ -907,7 +907,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
907907
f.jac(uprev, p, t)
908908
end
909909
elseif f.jac_prototype === nothing
910-
ArrayInterface.undefmatrix(u)
910+
ArrayInterface.zeromatrix(u)
911911
else
912912
deepcopy(f.jac_prototype)
913913
end
@@ -1003,4 +1003,4 @@ function resize_J_W!(cache, integrator, i)
10031003
end
10041004

10051005
nothing
1006-
end
1006+
end

lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,20 @@ issuccess_W(W::Number) = !iszero(W)
33
issuccess_W(::Any) = true
44

55
function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
6-
du = nothing, u = nothing, p = nothing, t = nothing,
7-
weight = nothing, solverdata = nothing,
86
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
9-
A !== nothing && (linsolve.A = A)
107
b !== nothing && (linsolve.b = b)
118
linu !== nothing && (linsolve.u = linu)
129

13-
Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
14-
linsolve.Pl
15-
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
16-
linsolve.Pr
17-
1810
_alg = unwrap_alg(integrator, true)
19-
20-
_Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev,
21-
solverdata)
22-
if (_Pl !== nothing || _Pr !== nothing)
23-
__Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl
24-
__Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr
25-
linsolve.Pl = __Pl
26-
linsolve.Pr = __Pr
11+
if !isnothing(A)
12+
if integrator isa DEIntegrator
13+
(;u, p, t) = integrator
14+
du = hasproperty(integrator, :du) ? integrator.du : nothing
15+
p = (du, u, p, t)
16+
reinit!(linsolve; A, p)
17+
else
18+
reinit!(linsolve; A)
19+
end
2720
end
2821

2922
linres = solve!(linsolve; reltol)
@@ -44,16 +37,15 @@ function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothi
4437
return linres
4538
end
4639

47-
function wrapprecs(_Pl::Nothing, _Pr::Nothing, weight, u)
48-
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
49-
Pr = Diagonal(_vec(weight))
50-
Pl, Pr
51-
end
52-
53-
function wrapprecs(_Pl, _Pr, weight, u)
54-
Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pl
55-
Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pr
56-
Pl, Pr
40+
function wrapprecs(linsolver, W, weight)
41+
if hasproperty(linsolver, :precs) && isnothing(linsolver.precs)
42+
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
43+
Pr = Diagonal(_vec(weight))
44+
precs = Returns((Pl, Pr))
45+
return remake(linsolver; precs)
46+
else
47+
return linsolver
48+
end
5749
end
5850

5951
Base.resize!(p::LinearSolve.LinearCache, i) = p

lib/OrdinaryDiffEqExtrapolation/src/OrdinaryDiffEqExtrapolation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import OrdinaryDiffEqCore: alg_order, alg_maximum_order, get_current_adaptive_or
1313
OrdinaryDiffEqAdaptiveAlgorithm,
1414
OrdinaryDiffEqAdaptiveImplicitAlgorithm,
1515
alg_cache, CompiledFloats, @threaded, stepsize_controller!,
16-
DEFAULT_PRECS, full_cache,
16+
full_cache,
1717
constvalue, PolyesterThreads, Sequential, BaseThreads,
1818
_digest_beta1_beta2, timedepentdtmin, _unwrap_val,
1919
_reshape, _vec, get_fsalfirstlast, generic_solver_docstring,

0 commit comments

Comments
 (0)