Skip to content

Commit e2914bb

Browse files
authored
refactor!: move preconditioners inside linear solvers (#485)
* refactor!: move preconditioners inside linear solvers * test: precs need concrete jacobian * docs: fix references
1 parent fcee7a1 commit e2914bb

File tree

18 files changed

+113
-148
lines changed

18 files changed

+113
-148
lines changed

docs/src/native/solvers.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ documentation.
1818
uses the LinearSolve.jl default algorithm choice. For more information on available
1919
algorithm choices, see the
2020
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
21-
- `precs`: the choice of preconditioners for the linear solver. Defaults to using no
22-
preconditioners. For more information on specifying preconditioners for LinearSolve
23-
algorithms, consult the
24-
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
2521
- `linesearch`: the line search algorithm to use. Defaults to
2622
[`NoLineSearch()`](@extref LineSearch.NoLineSearch), which means that no line search is
2723
performed.

docs/src/release_notes.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
### Breaking Changes in `NonlinearSolve.jl` v4
66

77
- `ApproximateJacobianSolveAlgorithm` has been renamed to `QuasiNewtonAlgorithm`.
8+
- Preconditioners for the linear solver needs to be specified with the linear solver
9+
instead of `precs` keyword argument.
810
- See [common breaking changes](@ref common-breaking-changes-v4v2) below.
911

1012
### Breaking Changes in `SimpleNonlinearSolve.jl` v2

docs/src/tutorials/large_systems.md

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ end
100100
101101
u0 = init_brusselator_2d(xyd_brusselator)
102102
prob_brusselator_2d = NonlinearProblem(
103-
brusselator_2d_loop, u0, p; abstol = 1e-10, reltol = 1e-10)
103+
brusselator_2d_loop, u0, p; abstol = 1e-10, reltol = 1e-10
104+
)
104105
```
105106

106107
## Choosing Jacobian Types
@@ -140,7 +141,8 @@ using SparseConnectivityTracer
140141
141142
prob_brusselator_2d_autosparse = NonlinearProblem(
142143
NonlinearFunction(brusselator_2d_loop; sparsity = TracerSparsityDetector()),
143-
u0, p; abstol = 1e-10, reltol = 1e-10)
144+
u0, p; abstol = 1e-10, reltol = 1e-10
145+
)
144146
145147
@btime solve(prob_brusselator_2d_autosparse,
146148
NewtonRaphson(; autodiff = AutoForwardDiff(; chunksize = 12)));
@@ -235,7 +237,7 @@ choices, see the
235237

236238
Any [LinearSolve.jl-compatible preconditioner](https://docs.sciml.ai/LinearSolve/stable/basics/Preconditioners/)
237239
can be used as a preconditioner in the linear solver interface. To define preconditioners,
238-
one must define a `precs` function in compatible with nonlinear solvers which returns the
240+
one must define a `precs` function in compatible with linear solvers which returns the
239241
left and right preconditioners, matrices which approximate the inverse of `W = I - gamma*J`
240242
used in the solution of the ODE. An example of this with using
241243
[IncompleteLU.jl](https://github.com/haampie/IncompleteLU.jl) is as follows:
@@ -244,26 +246,18 @@ used in the solution of the ODE. An example of this with using
244246
# FIXME: On 1.10+ this is broken. Skipping this for now.
245247
using IncompleteLU
246248

247-
function incompletelu(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
248-
if newW === nothing || newW
249-
Pl = ilu(W, τ = 50.0)
250-
else
251-
Pl = Plprev
252-
end
253-
Pl, nothing
254-
end
249+
incompletelu(W, p = nothing) = ilu(W, τ = 50.0), LinearAlgebra.I
255250

256251
@btime solve(prob_brusselator_2d_sparse,
257-
NewtonRaphson(linsolve = KrylovJL_GMRES(), precs = incompletelu, concrete_jac = true));
252+
NewtonRaphson(linsolve = KrylovJL_GMRES(precs = incompletelu), concrete_jac = true)
253+
);
258254
nothing # hide
259255
```
260256

261257
Notice a few things about this preconditioner. This preconditioner uses the sparse Jacobian,
262258
and thus we set `concrete_jac = true` to tell the algorithm to generate the Jacobian
263-
(otherwise, a Jacobian-free algorithm is used with GMRES by default). Then `newW = true`
264-
whenever a new `W` matrix is computed, and `newW = nothing` during the startup phase of the
265-
solver. Thus, we do a check `newW === nothing || newW` and when true, it's only at these
266-
points when we update the preconditioner, otherwise we just pass on the previous version.
259+
(otherwise, a Jacobian-free algorithm is used with GMRES by default).
260+
267261
We use `convert(AbstractMatrix,W)` to get the concrete `W` matrix (matching `jac_prototype`,
268262
thus `SpraseMatrixCSC`) which we can use in the preconditioner's definition. Then we use
269263
`IncompleteLU.ilu` on that sparse matrix to generate the preconditioner. We return
@@ -279,39 +273,36 @@ which is more automatic. The setup is very similar to before:
279273
```@example ill_conditioned_nlprob
280274
using AlgebraicMultigrid
281275
282-
function algebraicmultigrid(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
283-
if newW === nothing || newW
284-
Pl = aspreconditioner(ruge_stuben(convert(AbstractMatrix, W)))
285-
else
286-
Pl = Plprev
287-
end
288-
Pl, nothing
276+
function algebraicmultigrid(W, p = nothing)
277+
return aspreconditioner(ruge_stuben(convert(AbstractMatrix, W))), LinearAlgebra.I
289278
end
290279
291280
@btime solve(prob_brusselator_2d_sparse,
292-
NewtonRaphson(linsolve = KrylovJL_GMRES(), precs = algebraicmultigrid,
293-
concrete_jac = true));
281+
NewtonRaphson(
282+
linsolve = KrylovJL_GMRES(; precs = algebraicmultigrid), concrete_jac = true
283+
)
284+
);
294285
nothing # hide
295286
```
296287

297288
or with a Jacobi smoother:
298289

299290
```@example ill_conditioned_nlprob
300-
function algebraicmultigrid2(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
301-
if newW === nothing || newW
302-
A = convert(AbstractMatrix, W)
303-
Pl = AlgebraicMultigrid.aspreconditioner(AlgebraicMultigrid.ruge_stuben(
304-
A, presmoother = AlgebraicMultigrid.Jacobi(rand(size(A, 1))),
305-
postsmoother = AlgebraicMultigrid.Jacobi(rand(size(A, 1)))))
306-
else
307-
Pl = Plprev
308-
end
309-
Pl, nothing
291+
function algebraicmultigrid2(W, p = nothing)
292+
A = convert(AbstractMatrix, W)
293+
Pl = AlgebraicMultigrid.aspreconditioner(AlgebraicMultigrid.ruge_stuben(
294+
A, presmoother = AlgebraicMultigrid.Jacobi(rand(size(A, 1))),
295+
postsmoother = AlgebraicMultigrid.Jacobi(rand(size(A, 1)))
296+
))
297+
return Pl, LinearAlgebra.I
310298
end
311299
312-
@btime solve(prob_brusselator_2d_sparse,
313-
NewtonRaphson(linsolve = KrylovJL_GMRES(), precs = algebraicmultigrid2,
314-
concrete_jac = true));
300+
@btime solve(
301+
prob_brusselator_2d_sparse,
302+
NewtonRaphson(
303+
linsolve = KrylovJL_GMRES(precs = algebraicmultigrid2), concrete_jac = true
304+
)
305+
);
315306
nothing # hide
316307
```
317308

lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,33 +11,15 @@ using LinearAlgebra: ColumnNorm
1111
using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils
1212

1313
function (cache::LinearSolveJLCache)(;
14-
A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing,
15-
cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs...
14+
A = nothing, b = nothing, linu = nothing,
15+
reuse_A_if_factorization = false, verbose = true, kwargs...
1616
)
1717
cache.stats.nsolve += 1
1818

1919
update_A!(cache, A, reuse_A_if_factorization)
2020
b !== nothing && setproperty!(cache.lincache, :b, b)
2121
linu !== nothing && NonlinearSolveBase.set_lincache_u!(cache, linu)
2222

23-
Plprev = cache.lincache.Pl
24-
Prprev = cache.lincache.Pr
25-
26-
if cache.precs === nothing
27-
Pl, Pr = nothing, nothing
28-
else
29-
Pl, Pr = cache.precs(
30-
cache.lincache.A, du, linu, p, nothing,
31-
A !== nothing, Plprev, Prprev, cachedata
32-
)
33-
end
34-
35-
if Pl !== nothing || Pr !== nothing
36-
Pl, Pr = NonlinearSolveBase.wrap_preconditioners(Pl, Pr, linu)
37-
cache.lincache.Pl = Pl
38-
cache.lincache.Pr = Pr
39-
end
40-
4123
linres = solve!(cache.lincache)
4224
cache.lincache = linres.cache
4325
# Unfortunately LinearSolve.jl doesn't have the most uniform ReturnCode handling
@@ -58,7 +40,8 @@ function (cache::LinearSolveJLCache)(;
5840
linprob = LinearProblem(A, b; u0 = linres.u)
5941
cache.additional_lincache = init(
6042
linprob, QRFactorization(ColumnNorm()); alias_u0 = false,
61-
alias_A = false, alias_b = false, cache.lincache.Pl, cache.lincache.Pr)
43+
alias_A = false, alias_b = false
44+
)
6245
else
6346
cache.additional_lincache.A = A
6447
cache.additional_lincache.b = b

lib/NonlinearSolveBase/src/descent/damped_newton.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""
2-
DampedNewtonDescent(;
3-
linsolve = nothing, precs = nothing, initial_damping, damping_fn
4-
)
2+
DampedNewtonDescent(; linsolve = nothing, initial_damping, damping_fn)
53
64
A Newton descent algorithm with damping. The damping factor is computed using the
75
`damping_fn` function. The descent direction is computed as ``(JᵀJ + λDᵀD) δu = -fu``. For
@@ -20,7 +18,6 @@ The damping factor returned must be a non-negative number.
2018
"""
2119
@kwdef @concrete struct DampedNewtonDescent <: AbstractDescentDirection
2220
linsolve = nothing
23-
precs = nothing
2421
initial_damping
2522
damping_fn <: AbstractDampingFunction
2623
end

lib/NonlinearSolveBase/src/descent/dogleg.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Dogleg(; linsolve = nothing, precs = nothing)
2+
Dogleg(; linsolve = nothing)
33
44
Switch between Newton's method and the steepest descent method depending on the size of the
55
trust region. The trust region is specified via keyword argument `trust_region` to
@@ -15,18 +15,18 @@ end
1515
supports_trust_region(::Dogleg) = true
1616
get_linear_solver(alg::Dogleg) = get_linear_solver(alg.newton_descent)
1717

18-
function Dogleg(; linsolve = nothing, precs = nothing, damping = Val(false),
18+
function Dogleg(; linsolve = nothing, damping = Val(false),
1919
damping_fn = missing, initial_damping = missing, kwargs...)
2020
if !Utils.unwrap_val(damping)
21-
return Dogleg(NewtonDescent(; linsolve, precs), SteepestDescent(; linsolve, precs))
21+
return Dogleg(NewtonDescent(; linsolve), SteepestDescent(; linsolve))
2222
end
2323
if damping_fn === missing || initial_damping === missing
2424
throw(ArgumentError("`damping_fn` and `initial_damping` must be supplied if \
2525
`damping = Val(true)`."))
2626
end
2727
return Dogleg(
28-
DampedNewtonDescent(; linsolve, precs, damping_fn, initial_damping),
29-
SteepestDescent(; linsolve, precs)
28+
DampedNewtonDescent(; linsolve, damping_fn, initial_damping),
29+
SteepestDescent(; linsolve)
3030
)
3131
end
3232

lib/NonlinearSolveBase/src/descent/newton.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
NewtonDescent(; linsolve = nothing, precs = nothing)
2+
NewtonDescent(; linsolve = nothing)
33
44
Compute the descent direction as ``J δu = -fu``. For non-square Jacobian problems, this is
55
commonly referred to as the Gauss-Newton Descent.
@@ -8,7 +8,6 @@ See also [`Dogleg`](@ref), [`SteepestDescent`](@ref), [`DampedNewtonDescent`](@r
88
"""
99
@kwdef @concrete struct NewtonDescent <: AbstractDescentDirection
1010
linsolve = nothing
11-
precs = nothing
1211
end
1312

1413
supports_line_search(::NewtonDescent) = true
@@ -103,15 +102,15 @@ function InternalAPI.solve!(
103102
@static_timeit cache.timer "linear solve" begin
104103
linres = cache.lincache(;
105104
A = Utils.maybe_symmetric(cache.JᵀJ_cache), b = cache.Jᵀfu_cache,
106-
kwargs..., linu = Utils.safe_vec(δu), du = Utils.safe_vec(δu),
105+
kwargs..., linu = Utils.safe_vec(δu),
107106
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1))
108107
)
109108
end
110109
else
111110
@static_timeit cache.timer "linear solve" begin
112111
linres = cache.lincache(;
113112
A = J, b = Utils.safe_vec(fu),
114-
kwargs..., linu = Utils.safe_vec(δu), du = Utils.safe_vec(δu),
113+
kwargs..., linu = Utils.safe_vec(δu),
115114
reuse_A_if_factorization = !new_jacobian || idx !== Val(1)
116115
)
117116
end

lib/NonlinearSolveBase/src/descent/steepest.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
SteepestDescent(; linsolve = nothing, precs = nothing)
2+
SteepestDescent(; linsolve = nothing)
33
44
Compute the descent direction as ``δu = -Jᵀfu``. The linear solver and preconditioner are
55
only used if `J` is provided in the inverted form.
@@ -8,7 +8,6 @@ See also [`Dogleg`](@ref), [`NewtonDescent`](@ref), [`DampedNewtonDescent`](@ref
88
"""
99
@kwdef @concrete struct SteepestDescent <: AbstractDescentDirection
1010
linsolve = nothing
11-
precs = nothing
1211
end
1312

1413
supports_line_search(::SteepestDescent) = true
@@ -57,7 +56,6 @@ function InternalAPI.solve!(
5756
A = J === nothing ? nothing : transpose(J)
5857
linres = cache.lincache(;
5958
A, b = Utils.safe_vec(fu), kwargs..., linu = Utils.safe_vec(δu),
60-
du = Utils.safe_vec(δu),
6159
reuse_A_if_factorization = !new_jacobian || idx !== Val(1)
6260
)
6361
δu = Utils.restructure(SciMLBase.get_du(cache, idx), linres.u)

lib/NonlinearSolveBase/src/jacobian.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ Construct a cache for the Jacobian of `f` w.r.t. `u`.
2929
- `jvp_autodiff`: Automatic Differentiation or Finite Differencing backend for computing
3030
the Jacobian-vector product.
3131
- `linsolve`: Linear Solver Algorithm used to determine if we need a concrete jacobian
32-
or if possible we can just use a [`SciMLJacobianOperators.JacobianOperator`](@ref)
33-
instead.
32+
or if possible we can just use a `JacobianOperator` instead.
3433
"""
3534
function construct_jacobian_cache(
3635
prob, alg, f::NonlinearFunction, fu, u = prob.u0, p = prob.p; stats,

lib/NonlinearSolveBase/src/linear_solve.jl

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ end
77
lincache
88
linsolve
99
additional_lincache::Any
10-
precs
1110
stats::NLStats
1211
end
1312

@@ -34,8 +33,8 @@ handled:
3433
3534
```julia
3635
(cache::LinearSolverCache)(;
37-
A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing,
38-
weight = nothing, cachedata = nothing, reuse_A_if_factorization = false, kwargs...)
36+
A = nothing, b = nothing, linu = nothing, reuse_A_if_factorization = false, kwargs...
37+
)
3938
```
4039
4140
Returns the solution of the system `u` and stores the updated cache in `cache.lincache`.
@@ -60,15 +59,11 @@ aliasing arguments even after cache construction, i.e., if we passed in an `A` t
6059
not mutated, we do this by copying over `A` to a preconstructed cache.
6160
"""
6261
function construct_linear_solver(alg, linsolve, A, b, u; stats, kwargs...)
63-
no_preconditioner = !hasfield(typeof(alg), :precs) || alg.precs === nothing
64-
6562
if (A isa Number && b isa Number) || (A isa Diagonal)
6663
return NativeJLLinearSolveCache(A, b, stats)
6764
elseif linsolve isa typeof(\)
68-
!no_preconditioner &&
69-
error("Default Julia Backsolve Operator `\\` doesn't support Preconditioners")
7065
return NativeJLLinearSolveCache(A, b, stats)
71-
elseif no_preconditioner && linsolve === nothing
66+
elseif linsolve === nothing
7267
if (A isa SMatrix || A isa WrappedArray{<:Any, <:SMatrix})
7368
return NativeJLLinearSolveCache(A, b, stats)
7469
end
@@ -78,17 +73,9 @@ function construct_linear_solver(alg, linsolve, A, b, u; stats, kwargs...)
7873
@bb u_cache = copy(u_fixed)
7974
linprob = LinearProblem(A, b; u0 = u_cache, kwargs...)
8075

81-
if no_preconditioner
82-
precs, Pl, Pr = nothing, nothing, nothing
83-
else
84-
precs = alg.precs
85-
Pl, Pr = precs(A, nothing, u, ntuple(Returns(nothing), 6)...)
86-
end
87-
Pl, Pr = wrap_preconditioners(Pl, Pr, u)
88-
8976
# unlias here, we will later use these as caches
90-
lincache = init(linprob, linsolve; alias_A = false, alias_b = false, Pl, Pr)
91-
return LinearSolveJLCache(lincache, linsolve, nothing, precs, stats)
77+
lincache = init(linprob, linsolve; alias_A = false, alias_b = false)
78+
return LinearSolveJLCache(lincache, linsolve, nothing, stats)
9279
end
9380

9481
function (cache::NativeJLLinearSolveCache)(;

0 commit comments

Comments
 (0)