Skip to content

refactor: update ODE_NLProbData #1067

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 0 additions & 45 deletions src/ODE_nlsolve.jl

This file was deleted.

2 changes: 1 addition & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ Internal. Used for signifying the AD context comes from a Mooncake.jl context.
struct MooncakeOriginator <: ADOriginator end

include("initialization.jl")
include("ODE_nlsolve.jl")
include("odenlstep.jl")
include("utils.jl")
include("function_wrappers.jl")
include("scimlfunctions.jl")
Expand Down
29 changes: 29 additions & 0 deletions src/odenlstep.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
$(TYPEDEF)

A collection of all the data required for custom ODE Nonlinear problem solving
"""
struct ODENLStepData{NLProb, SetU0, SetGammaC, SetOuterTmp, SetInnerTmp, NLProbMap}
"""
The `AbstractNonlinearProblem` to define custom nonlinear problems to be used for
implicit time discretizations. This allows to use extra structure of the ODE function (e.g.
multi-level structure). The nonlinear function must match that form of the function implicit
ODE integration algorithms need do solve the a nonlinear problems,
specifically of the form `M*z = outer_tmp + γ₁⋅f(γ₂⋅z+inner_tmp,p,t_c)`.
Here `z` is the stage solution vector, `p` is the parameter of the ODE problem, `t_c` is
the time of evaluation (`t_c = t + c*dt`), `γ₁` and `γ₂` are some scaling factors determined
by the solver algorithm and the temporary variables are some compatible vectors set by the specific solver.
The inner nonlinear function of the nonlinear problem is in general of the form `g(z,p') = 0` such that
`g(z,p') = γ₁⋅f(γ₂⋅z+inner_tmp,p,t_c) + outer_tmp - M*z = 0`.
"""
nlprob::NLProb
u0perm::SetU0
set_γ_c::SetGammaC
set_outer_tmp::SetOuterTmp
set_inner_tmp::SetInnerTmp
"""
A function which takes the solution of `nlprob` and returns
the state vector of the original problem.
"""
nlprobmap::NLProbMap
end
50 changes: 25 additions & 25 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
O, TCV,
SYS, ID <: Union{Nothing, OverrideInitData}, NLP <: Union{Nothing, ODE_NLProbData}} <:
SYS, ID <: Union{Nothing, OverrideInitData}, NLP <: Union{Nothing, ODENLStepData}} <:
AbstractODEFunction{iip}
f::F
mass_matrix::TMM
Expand All @@ -428,7 +428,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
colorvec::TCV
sys::SYS
initialization_data::ID
nlprob_data::NLP
nlstep_data::NLP
end

@doc doc"""
Expand Down Expand Up @@ -532,7 +532,7 @@ information on generating the SplitFunction from this symbolic engine.
struct SplitFunction{
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
TPJ, O, TCV, SYS, ID <: Union{Nothing, OverrideInitData},
NLP <: Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
NLP <: Union{Nothing, ODENLStepData}} <: AbstractODEFunction{iip}
f1::F1
f2::F2
mass_matrix::TMM
Expand All @@ -552,7 +552,7 @@ struct SplitFunction{
colorvec::TCV
sys::SYS
initialization_data::ID
nlprob_data::NLP
nlstep_data::NLP
end

@doc doc"""
Expand Down Expand Up @@ -2691,7 +2691,7 @@ function ODEFunction{iip, specialize}(f;
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
initialization_data = __has_initialization_data(f) ? f.initialization_data :
nothing,
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing
nlstep_data = __has_nlstep_data(f) ? f.nlstep_data : nothing
) where {iip,
specialize
}
Expand Down Expand Up @@ -2749,11 +2749,11 @@ function ODEFunction{iip, specialize}(f;
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODENLStepData}}(
_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata, nlprob_data)
observed, _colorvec, sys, initdata, nlstep_data)
elseif specialize === false
ODEFunction{iip, FunctionWrapperSpecialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2762,11 +2762,11 @@ function ODEFunction{iip, specialize}(f;
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initdata), typeof(nlprob_data)}(_f, mass_matrix,
typeof(sys), typeof(initdata), typeof(nlstep_data)}(_f, mass_matrix,
analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata, nlprob_data)
observed, _colorvec, sys, initdata, nlstep_data)
else
ODEFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2775,11 +2775,11 @@ function ODEFunction{iip, specialize}(f;
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initdata), typeof(nlprob_data)}(
typeof(sys), typeof(initdata), typeof(nlstep_data)}(
_f, mass_matrix, analytic, tgrad,
jac, jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata, nlprob_data)
observed, _colorvec, sys, initdata, nlstep_data)
end
end

Expand All @@ -2796,23 +2796,23 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
Any, Any, Any, Any, typeof(f.jac_prototype),
typeof(f.sparsity), Any, Any, Any, Any,
Any, typeof(f.colorvec),
typeof(f.sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
typeof(f.sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODENLStepData}}(
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob_data)
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlstep_data)
else
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
typeof(f.analytic), typeof(f.tgrad),
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
typeof(f.paramjac),
typeof(f.observed), typeof(f.colorvec),
typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob_data)}(
typeof(f.sys), typeof(f.initialization_data), typeof(f.nlstep_data)}(
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob_data)
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlstep_data)
end
end

Expand Down Expand Up @@ -2948,7 +2948,7 @@ end
f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp,
vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
observed, colorvec, sys, initializeprob = nothing, update_initializeprob! = nothing,
initializeprobmap = nothing, initializeprobpmap = nothing, initialization_data = nothing, nlprob_data = nothing)
initializeprobmap = nothing, initializeprobpmap = nothing, initialization_data = nothing, nlstep_data = nothing)
f1 = ODEFunction(f1)
f2 = ODEFunction(f2)

Expand All @@ -2966,11 +2966,11 @@ end
typeof(_func_cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
typeof(sys), typeof(initdata), typeof(nlprob_data)}(
typeof(sys), typeof(initdata), typeof(nlstep_data)}(
f1, f2, mass_matrix,
_func_cache, analytic, tgrad, jac, jvp, vjp,
jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initdata, nlprob_data)
initdata, nlstep_data)
end
function SplitFunction{iip, specialize}(f1, f2;
mass_matrix = __has_mass_matrix(f1) ?
Expand Down Expand Up @@ -3007,7 +3007,7 @@ function SplitFunction{iip, specialize}(f1, f2;
f1.update_initializeprob! : nothing,
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing,
nlprob_data = __has_nlprob_data(f1) ? f1.nlprob_data : nothing,
nlstep_data = __has_nlstep_data(f1) ? f1.nlstep_data : nothing,
initialization_data = __has_initialization_data(f1) ? f1.initialization_data :
nothing
) where {iip,
Expand All @@ -3021,24 +3021,24 @@ function SplitFunction{iip, specialize}(f1, f2;
if specialize === NoSpecialize
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any,
Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODENLStepData}}(
f1, f2, mass_matrix, _func_cache,
analytic,
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac,
observed, colorvec, sys, initdata, nlprob_data)
observed, colorvec, sys, initdata, nlstep_data)
else
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
typeof(_func_cache), typeof(analytic),
typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp),
typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
typeof(colorvec),
typeof(sys), typeof(initdata), typeof(nlprob_data)}(f1, f2,
typeof(sys), typeof(initdata), typeof(nlstep_data)}(f1, f2,
mass_matrix, _func_cache, analytic, tgrad, jac,
jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initdata, nlprob_data)
initdata, nlstep_data)
end
end

Expand Down Expand Up @@ -4779,7 +4779,7 @@ function ODEInputFunction{iip, specialize}(f;
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
initialization_data = __has_initialization_data(f) ? f.initialization_data :
nothing,
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing
nlstep_data = __has_nlstep_data(f) ? f.nlstep_data : nothing
) where {iip,
specialize
}
Expand Down Expand Up @@ -4938,7 +4938,7 @@ __has_colorvec(f) = isdefined(f, :colorvec)
__has_sys(f) = isdefined(f, :sys)
__has_analytic_full(f) = isdefined(f, :analytic_full)
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
__has_nlprob_data(f) = isdefined(f, :nlprob_data)
__has_nlstep_data(f) = isdefined(f, :nlstep_data)
function __has_initializeprob(f)
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprob)
end
Expand Down
Loading