Skip to content

Commit 7450590

Browse files
committed
Integrate the linear solve better into the algorithm
1 parent dd4a2fb commit 7450590

File tree

7 files changed

+85
-18
lines changed

7 files changed

+85
-18
lines changed

src/core/approximate_jacobian.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,27 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
291291
cache.descent_cache, J, cache.fu, cache.u; new_jacobian, cache.kwargs...)
292292
end
293293
end
294+
295+
if !descent_result.linsolve_success
296+
if new_jacobian && cache.steps_since_last_reset == 0
297+
# Extremely pathological case. Jacobian was just reset and linear solve
298+
# failed. Should ideally never happen in practice unless true jacobian init
299+
# is used.
300+
cache.retcode = LinearSolveFailureCode
301+
cache.force_stop = true
302+
return
303+
else
304+
# Force a reinit because the problem is currently un-solvable
305+
if !haskey(cache.kwargs, :verbose) || cache.kwargs[:verbose]
306+
@warn "Linear Solve Failed but Jacobian Information is not current. \
307+
Retrying with reinitialized Approximate Jacobian."
308+
end
309+
cache.force_reinit = true
310+
__step!(cache; recompute_jacobian = true)
311+
return
312+
end
313+
end
314+
294315
δu, descent_intermediates = descent_result.δu, descent_result.extras
295316

296317
if descent_result.success

src/core/generalized_first_order.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,27 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
230230
cache.descent_cache, J, cache.fu, cache.u; new_jacobian, cache.kwargs...)
231231
end
232232
end
233+
234+
if !descent_result.linsolve_success
235+
if new_jacobian
236+
# Jacobian Information is current and linear solve failed terminate the solve
237+
cache.retcode = LinearSolveFailureCode
238+
cache.force_stop = true
239+
return
240+
else
241+
# Jacobian Information is not current and linear solve failed, recompute
242+
# Jacobian
243+
if !haskey(cache.kwargs, :verbose) || cache.kwargs[:verbose]
244+
@warn "Linear Solve Failed but Jacobian Information is not current. \
245+
Retrying with updated Jacobian."
246+
end
247+
# In the 2nd call the `new_jacobian` is guaranteed to be `true`.
248+
cache.make_new_jacobian = true
249+
__step!(cache; recompute_jacobian = true, kwargs...)
250+
return
251+
end
252+
end
253+
233254
δu, descent_intermediates = descent_result.δu, descent_result.extras
234255

235256
if descent_result.success

src/descent/common.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""
2-
DescentResult(; δu = missing, u = missing, success::Bool = true, extras = (;))
2+
DescentResult(; δu = missing, u = missing, success::Bool = true,
3+
linsolve_success::Bool = true, extras = (;))
34
45
Construct a `DescentResult` object.
56
@@ -9,6 +10,7 @@ Construct a `DescentResult` object.
910
- `u`: The new iterate. This is provided only for multi-step methods currently.
1011
- `success`: Certain Descent Algorithms can reject a descent direction for example
1112
[`GeodesicAcceleration`](@ref).
13+
- `linsolve_success`: Whether the line search was successful.
1214
- `extras`: A named tuple containing intermediates computed during the solve.
1315
For example, [`GeodesicAcceleration`](@ref) returns `NamedTuple{(:v, :a)}` containing
1416
the "velocity" and "acceleration" terms.
@@ -17,10 +19,12 @@ Construct a `DescentResult` object.
1719
δu
1820
u
1921
success::Bool
22+
linsolve_success::Bool
2023
extras
2124
end
2225

23-
function DescentResult(; δu = missing, u = missing, success::Bool = true, extras = (;))
26+
function DescentResult(; δu = missing, u = missing, success::Bool = true,
27+
linsolve_success::Bool = true, extras = (;))
2428
@assert δu !== missing || u !== missing
25-
return DescentResult(δu, u, success, extras)
29+
return DescentResult(δu, u, success, linsolve_success, extras)
2630
end

src/descent/damped_newton.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,14 @@ function __internal_solve!(cache::DampedNewtonDescentCache{INV, mode}, J, fu,
200200
end
201201

202202
@static_timeit cache.timer "linear solve" begin
203-
δu = cache.lincache(;
203+
linres = cache.lincache(;
204204
A, b, reuse_A_if_factorization = !new_jacobian && !recompute_A,
205205
kwargs..., linu = _vec(δu))
206-
δu = _restructure(get_du(cache, idx), δu)
206+
δu = _restructure(get_du(cache, idx), linres.u)
207+
if !linres.success
208+
set_du!(cache, δu, idx)
209+
return DescentResult(; δu, success = false, linsolve_success = false)
210+
end
207211
end
208212

209213
@bb @. δu *= -1

src/descent/newton.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,14 @@ function __internal_solve!(
8181
@bb δu = J × vec(fu)
8282
else
8383
@static_timeit cache.timer "linear solve" begin
84-
δu = cache.lincache(;
84+
linres = cache.lincache(;
8585
A = J, b = _vec(fu), kwargs..., linu = _vec(δu), du = _vec(δu),
8686
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
87-
δu = _restructure(get_du(cache, idx), δu)
87+
δu = _restructure(get_du(cache, idx), linres.u)
88+
if !linres.success
89+
set_du!(cache, δu, idx)
90+
return DescentResult(; δu, success = false, linsolve_success = false)
91+
end
8892
end
8993
end
9094
@bb @. δu *= -1
@@ -102,10 +106,15 @@ function __internal_solve!(
102106
end
103107
@bb cache.Jᵀfu_cache = transpose(J) × vec(fu)
104108
@static_timeit cache.timer "linear solve" begin
105-
δu = cache.lincache(; A = __maybe_symmetric(cache.JᵀJ_cache), b = cache.Jᵀfu_cache,
109+
linres = cache.lincache(;
110+
A = __maybe_symmetric(cache.JᵀJ_cache), b = cache.Jᵀfu_cache,
106111
kwargs..., linu = _vec(δu), du = _vec(δu),
107112
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
108-
δu = _restructure(get_du(cache, idx), δu)
113+
δu = _restructure(get_du(cache, idx), linres.u)
114+
if !linres.success
115+
set_du!(cache, δu, idx)
116+
return DescentResult(; δu, success = false, linsolve_success = false)
117+
end
109118
end
110119
@bb @. δu *= -1
111120
set_du!(cache, δu, idx)

src/descent/steepest.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,14 @@ function __internal_solve!(cache::SteepestDescentCache{INV}, J, fu, u, idx::Val
5454
if INV
5555
A = J === nothing ? nothing : transpose(J)
5656
@static_timeit cache.timer "linear solve" begin
57-
δu = cache.lincache(;
57+
linres = cache.lincache(;
5858
A, b = _vec(fu), kwargs..., linu = _vec(δu), du = _vec(δu),
5959
reuse_A_if_factorization = !new_jacobian || idx !== Val(1))
60-
δu = _restructure(get_du(cache, idx), δu)
60+
δu = _restructure(get_du(cache, idx), linres.u)
61+
if !linres.success
62+
set_du!(cache, δu, idx)
63+
return DescentResult(; δu, success = false, linsolve_success = false)
64+
end
6165
end
6266
else
6367
@assert J!==nothing "`J` must be provided when `pre_inverted = Val(false)`."

src/internal/linear_solve.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ function LinearSolverCache(alg, linsolve, A, b, u; kwargs...)
105105
return LinearSolverCache(lincache, linsolve, nothing, nothing, nothing, precs, 0, 0)
106106
end
107107

108+
@kwdef @concrete struct LinearSolveResult
109+
u
110+
success::Bool = true
111+
end
112+
108113
# Direct Linear Solve Case without Caching
109114
function (cache::LinearSolverCache{Nothing})(;
110115
A = nothing, b = nothing, linu = nothing, kwargs...)
@@ -119,7 +124,7 @@ function (cache::LinearSolverCache{Nothing})(;
119124
else
120125
res = cache.A \ cache.b
121126
end
122-
return res
127+
return LinearSolveResult(; u = res)
123128
end
124129

125130
# Use LinearSolve.jl
@@ -154,11 +159,7 @@ function (cache::LinearSolverCache)(;
154159
cache.lincache.Pr = Pr
155160
end
156161

157-
# display(A)
158-
159162
linres = solve!(cache.lincache)
160-
# @show cache.lincache.cacheval
161-
# @show LinearAlgebra.issuccess(cache.lincache.cacheval)
162163
cache.lincache = linres.cache
163164
# Unfortunately LinearSolve.jl doesn't have the most uniform ReturnCode handling
164165
if linres.retcode === ReturnCode.Failure
@@ -185,11 +186,14 @@ function (cache::LinearSolverCache)(;
185186
end
186187
linres = solve!(cache.additional_lincache)
187188
cache.additional_lincache = linres.cache
188-
return linres.u
189+
linres.retcode === ReturnCode.Failure &&
190+
return LinearSolveResult(; u = linres.u, success = false)
191+
return LinearSolveResult(; u = linres.u)
189192
end
193+
return LinearSolveResult(; u = linres.u, success = false)
190194
end
191195

192-
return linres.u
196+
return LinearSolveResult(; u = linres.u)
193197
end
194198

195199
@inline __update_A!(cache::LinearSolverCache, ::Nothing, reuse) = cache

0 commit comments

Comments
 (0)