Skip to content

Setup NonlinearSolveAlg with jacobian reuse #2727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/OrdinaryDiffEqCore/src/misc_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ function get_differential_vars(f, u)
end

isnewton(::Any) = false
isnonlinearsolve(::Any) = false

function _bool_to_ADType(::Val{true}, ::Val{CS}, _) where {CS}
Base.depwarn(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplici
OrdinaryDiffEqAdaptiveExponentialAlgorithm, @unpack,
AbstractNLSolver, nlsolve_f, issplit,
concrete_jac, unwrap_alg, OrdinaryDiffEqCache, _vec, standardtag,
isnewton, _unwrap_val,
isnewton, isnonlinearsolve, _unwrap_val,
set_new_W!, set_W_γdt!, alg_difftype, unwrap_cache, diffdir,
get_W, isfirstcall, isfirststage, isJcurrent,
get_new_W_γdt_cutoff,
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ function do_newJW(integrator, alg, nlsolver, repeat_step)::NTuple{2, Bool}
return true, true
end
# TODO: add `isJcurrent` support for Rosenbrock solvers
if !isnewton(nlsolver)
if !isnewton(nlsolver) && !isnonlinearsolve(nlsolver)
isfreshJ = !(integrator.alg isa CompositeAlgorithm) &&
(integrator.iter > 1 && errorfail && !integrator.u_modified)
return !isfreshJ, true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ using OrdinaryDiffEqCore: resize_nlsolver!, _initialize_dae!,

import OrdinaryDiffEqCore: _initialize_dae!, isnewton, get_W, isfirstcall, isfirststage,
isJcurrent, get_new_W_γdt_cutoff, resize_nlsolver!, apply_step!,
postamble!
postamble!, isnonlinearsolve

import OrdinaryDiffEqDifferentiation: update_W!, is_always_new, build_uf, build_J_W,
WOperator, StaticWOperator, wrapprecs,
build_jac_config, dolinsolve, alg_autodiff,
resize_jac_config!
resize_jac_config!, do_newJW

import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA,
StaticMatrix
Expand Down
19 changes: 17 additions & 2 deletions lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,14 @@ end
@unpack tstep, invγdt = cache

nlcache = nlsolver.cache.cache
step!(nlcache)

if is_always_new(nlsolver) || new_jac || new_W
recompute_jacobian = true
else
recompute_jacobian = false
end

step!(nlcache; recompute_jacobian)
nlsolver.ztmp = nlcache.u

ustep = compute_ustep(tmp, γ, z, method)
Expand All @@ -103,7 +110,15 @@ end
@unpack tstep, invγdt, atmp, ustep = cache

nlcache = nlsolver.cache.cache
step!(nlcache)
new_jac, new_W = do_newJW(integrator, integrator.alg, nlsolver, false)

if is_always_new(nlsolver) || new_jac || new_W
recompute_jacobian = true
else
recompute_jacobian = false
end

step!(nlcache; recompute_jacobian)
@.. broadcast=false ztmp=nlcache.u

ustep = compute_ustep!(ustep, tmp, γ, z, method)
Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function nlsolve!(nlsolver::NL, integrator::DiffEqBase.DEIntegrator,
# check divergence (not in initial step)
if iter > 1
θ = prev_θ = has_prev_θ ? max(0.3 * prev_θ, ndz / ndzprev) : ndz / ndzprev

# When one Newton iteration basically does nothing, it's likely that we
# are at the precision limit of floating point number. Thus, we just call
# it convergence/divergence according to `ndz` directly.
Expand Down Expand Up @@ -105,7 +105,7 @@ function nlsolve!(nlsolver::NL, integrator::DiffEqBase.DEIntegrator,
# don't trust θ for non-adaptive on first iter because the solver doesn't provide feedback
# for us to know whether our previous nlsolve converged sufficiently well
check_η_convergance = (iter > 1 ||
(isnewton(nlsolver) && isadaptive(integrator.alg)))
((isnewton(nlsolver) || isnonlinearsolve(nlsolver)) && isadaptive(integrator.alg)))
if (iter == 1 && ndz < 1e-5) ||
(check_η_convergance && η >= zero(η) && η * ndz < κ)
nlsolver.status = Convergence
Expand All @@ -114,7 +114,7 @@ function nlsolve!(nlsolver::NL, integrator::DiffEqBase.DEIntegrator,
end
end

if isnewton(nlsolver) && nlsolver.status == Divergence &&
if (isnewton(nlsolver) || isnonlinearsolve(nlsolver)) && nlsolver.status == Divergence &&
!isJcurrent(nlsolver, integrator)
nlsolver.status = TryAgain
nlsolver.nfails += 1
Expand Down
1 change: 1 addition & 0 deletions lib/OrdinaryDiffEqNonlinearSolve/src/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,5 @@ mutable struct NonlinearSolveCache{uType, tType, rateType, tType2, P, C} <:
invγdt::tType2
prob::P
cache::C
new_W::Bool
end
12 changes: 8 additions & 4 deletions lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ isnewton(nlsolver::AbstractNLSolver) = isnewton(nlsolver.cache)
isnewton(::AbstractNLSolverCache) = false
isnewton(::Union{NLNewtonCache, NLNewtonConstantCache}) = true

isnonlinearsolve(nlsolver::AbstractNLSolver) = isnonlinearsolve(nlsolver.cache)
isnonlinearsolve(::AbstractNLSolverCache) = false
isnonlinearsolve(::NonlinearSolveCache) = true

is_always_new(nlsolver::AbstractNLSolver) = is_always_new(nlsolver.alg)
check_div(nlsolver::AbstractNLSolver) = check_div(nlsolver.alg)
check_div(alg) = isdefined(alg, :check_div) ? alg.check_div : true
Expand All @@ -32,9 +36,9 @@ getnfails(_) = 0
getnfails(nlsolver::AbstractNLSolver) = nlsolver.nfails

set_new_W!(nlsolver::AbstractNLSolver, val::Bool)::Bool = set_new_W!(nlsolver.cache, val)
set_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache}, val::Bool)::Bool = nlcache.new_W = val
set_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache}, val::Bool)::Bool = nlcache.new_W = val
get_new_W!(nlsolver::AbstractNLSolver)::Bool = get_new_W!(nlsolver.cache)
get_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache})::Bool = nlcache.new_W
get_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache})::Bool = nlcache.new_W
get_new_W!(::AbstractNLSolverCache)::Bool = true

get_W(nlsolver::AbstractNLSolver) = get_W(nlsolver.cache)
Expand Down Expand Up @@ -231,7 +235,7 @@ function build_nlsolver(
end
prob = NonlinearProblem(NonlinearFunction{true}(nlf), ztmp, nlp_params)
cache = init(prob, nlalg.alg)
nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache)
nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache, true)
else
nlcache = NLNewtonCache(ustep, tstep, k, atmp, dz, J, W, true,
true, true, tType(dt), du1, uf, jac_config,
Expand Down Expand Up @@ -316,7 +320,7 @@ function build_nlsolver(
prob = NonlinearProblem(NonlinearFunction{false}(nlf), copy(ztmp), nlp_params)
cache = init(prob, nlalg.alg)
nlcache = NonlinearSolveCache(
nothing, tstep, nothing, nothing, invγdt, prob, cache)
nothing, tstep, nothing, nothing, invγdt, prob, cache, true)
else
nlcache = NLNewtonConstantCache(tstep, J, W, true, true, true, tType(dt), uf,
invγdt, tType(nlalg.new_W_dt_cutoff), t)
Expand Down
Loading