From 53963d16b9e574673606bd0adc8a7757f0e49d02 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 28 May 2025 00:27:22 +0000 Subject: [PATCH] Setup NonlinearSolveAlg with jacobian reuse --- lib/OrdinaryDiffEqCore/src/misc_utils.jl | 1 + .../src/OrdinaryDiffEqDifferentiation.jl | 2 +- .../src/derivative_utils.jl | 2 +- .../src/OrdinaryDiffEqNonlinearSolve.jl | 4 ++-- .../src/newton.jl | 19 +++++++++++++++++-- .../src/nlsolve.jl | 6 +++--- lib/OrdinaryDiffEqNonlinearSolve/src/type.jl | 1 + lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl | 12 ++++++++---- 8 files changed, 34 insertions(+), 13 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/misc_utils.jl b/lib/OrdinaryDiffEqCore/src/misc_utils.jl index 8b70f1e046..32072b32a0 100644 --- a/lib/OrdinaryDiffEqCore/src/misc_utils.jl +++ b/lib/OrdinaryDiffEqCore/src/misc_utils.jl @@ -133,6 +133,7 @@ function get_differential_vars(f, u) end isnewton(::Any) = false +isnonlinearsolve(::Any) = false function _bool_to_ADType(::Val{true}, ::Val{CS}, _) where {CS} Base.depwarn( diff --git a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl index b96839d07c..4a3bd4a5bf 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl @@ -35,7 +35,7 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplici OrdinaryDiffEqAdaptiveExponentialAlgorithm, @unpack, AbstractNLSolver, nlsolve_f, issplit, concrete_jac, unwrap_alg, OrdinaryDiffEqCache, _vec, standardtag, - isnewton, _unwrap_val, + isnewton, isnonlinearsolve, _unwrap_val, set_new_W!, set_W_γdt!, alg_difftype, unwrap_cache, diffdir, get_W, isfirstcall, isfirststage, isJcurrent, get_new_W_γdt_cutoff, diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl index 497fc18660..cc66d31787 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl @@ -458,7 +458,7 @@ function do_newJW(integrator, alg, nlsolver, repeat_step)::NTuple{2, Bool} return true, true end # TODO: add `isJcurrent` support for Rosenbrock solvers - if !isnewton(nlsolver) + if !isnewton(nlsolver) && !isnonlinearsolve(nlsolver) isfreshJ = !(integrator.alg isa CompositeAlgorithm) && (integrator.iter > 1 && errorfail && !integrator.u_modified) return !isfreshJ, true diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl index a04d64f93f..35281a5edc 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl @@ -48,12 +48,12 @@ using OrdinaryDiffEqCore: resize_nlsolver!, _initialize_dae!, import OrdinaryDiffEqCore: _initialize_dae!, isnewton, get_W, isfirstcall, isfirststage, isJcurrent, get_new_W_γdt_cutoff, resize_nlsolver!, apply_step!, - postamble! + postamble!, isnonlinearsolve import OrdinaryDiffEqDifferentiation: update_W!, is_always_new, build_uf, build_J_W, WOperator, StaticWOperator, wrapprecs, build_jac_config, dolinsolve, alg_autodiff, - resize_jac_config! + resize_jac_config!, do_newJW import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA, StaticMatrix diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl index 237289c011..1afd513752 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl @@ -81,7 +81,14 @@ end @unpack tstep, invγdt = cache nlcache = nlsolver.cache.cache - step!(nlcache) + + if is_always_new(nlsolver) || new_jac || new_W + recompute_jacobian = true + else + recompute_jacobian = false + end + + step!(nlcache; recompute_jacobian) nlsolver.ztmp = nlcache.u ustep = compute_ustep(tmp, γ, z, method) @@ -103,7 +110,15 @@ end @unpack tstep, invγdt, atmp, ustep = cache nlcache = nlsolver.cache.cache - step!(nlcache) + new_jac, new_W = do_newJW(integrator, integrator.alg, nlsolver, false) + + if is_always_new(nlsolver) || new_jac || new_W + recompute_jacobian = true + else + recompute_jacobian = false + end + + step!(nlcache; recompute_jacobian) @.. broadcast=false ztmp=nlcache.u ustep = compute_ustep!(ustep, tmp, γ, z, method) diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl index 1633531bdc..2af35f1cdc 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl @@ -65,7 +65,7 @@ function nlsolve!(nlsolver::NL, integrator::DiffEqBase.DEIntegrator, # check divergence (not in initial step) if iter > 1 θ = prev_θ = has_prev_θ ? max(0.3 * prev_θ, ndz / ndzprev) : ndz / ndzprev - + # When one Newton iteration basically does nothing, it's likely that we # are at the precision limit of floating point number. Thus, we just call # it convergence/divergence according to `ndz` directly. @@ -105,7 +105,7 @@ function nlsolve!(nlsolver::NL, integrator::DiffEqBase.DEIntegrator, # don't trust θ for non-adaptive on first iter because the solver doesn't provide feedback # for us to know whether our previous nlsolve converged sufficiently well check_η_convergance = (iter > 1 || - (isnewton(nlsolver) && isadaptive(integrator.alg))) + ((isnewton(nlsolver) || isnonlinearsolve(nlsolver)) && isadaptive(integrator.alg))) if (iter == 1 && ndz < 1e-5) || (check_η_convergance && η >= zero(η) && η * ndz < κ) nlsolver.status = Convergence @@ -114,7 +114,7 @@ function nlsolve!(nlsolver::NL, integrator::DiffEqBase.DEIntegrator, end end - if isnewton(nlsolver) && nlsolver.status == Divergence && + if (isnewton(nlsolver) || isnonlinearsolve(nlsolver)) && nlsolver.status == Divergence && !isJcurrent(nlsolver, integrator) nlsolver.status = TryAgain nlsolver.nfails += 1 diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/type.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/type.jl index 7b41c54903..05d89a349b 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/type.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/type.jl @@ -218,4 +218,5 @@ mutable struct NonlinearSolveCache{uType, tType, rateType, tType2, P, C} <: invγdt::tType2 prob::P cache::C + new_W::Bool end diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl index 485c0b56c6..4c71ca3333 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl @@ -14,6 +14,10 @@ isnewton(nlsolver::AbstractNLSolver) = isnewton(nlsolver.cache) isnewton(::AbstractNLSolverCache) = false isnewton(::Union{NLNewtonCache, NLNewtonConstantCache}) = true +isnonlinearsolve(nlsolver::AbstractNLSolver) = isnonlinearsolve(nlsolver.cache) +isnonlinearsolve(::AbstractNLSolverCache) = false +isnonlinearsolve(::NonlinearSolveCache) = true + is_always_new(nlsolver::AbstractNLSolver) = is_always_new(nlsolver.alg) check_div(nlsolver::AbstractNLSolver) = check_div(nlsolver.alg) check_div(alg) = isdefined(alg, :check_div) ? alg.check_div : true @@ -32,9 +36,9 @@ getnfails(_) = 0 getnfails(nlsolver::AbstractNLSolver) = nlsolver.nfails set_new_W!(nlsolver::AbstractNLSolver, val::Bool)::Bool = set_new_W!(nlsolver.cache, val) -set_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache}, val::Bool)::Bool = nlcache.new_W = val +set_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache}, val::Bool)::Bool = nlcache.new_W = val get_new_W!(nlsolver::AbstractNLSolver)::Bool = get_new_W!(nlsolver.cache) -get_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache})::Bool = nlcache.new_W +get_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache})::Bool = nlcache.new_W get_new_W!(::AbstractNLSolverCache)::Bool = true get_W(nlsolver::AbstractNLSolver) = get_W(nlsolver.cache) @@ -231,7 +235,7 @@ function build_nlsolver( end prob = NonlinearProblem(NonlinearFunction{true}(nlf), ztmp, nlp_params) cache = init(prob, nlalg.alg) - nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache) + nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache, true) else nlcache = NLNewtonCache(ustep, tstep, k, atmp, dz, J, W, true, true, true, tType(dt), du1, uf, jac_config, @@ -316,7 +320,7 @@ function build_nlsolver( prob = NonlinearProblem(NonlinearFunction{false}(nlf), copy(ztmp), nlp_params) cache = init(prob, nlalg.alg) nlcache = NonlinearSolveCache( - nothing, tstep, nothing, nothing, invγdt, prob, cache) + nothing, tstep, nothing, nothing, invγdt, prob, cache, true) else nlcache = NLNewtonConstantCache(tstep, J, W, true, true, true, tType(dt), uf, invγdt, tType(nlalg.new_W_dt_cutoff), t)