diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index c81774946d..2c259058b0 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -176,6 +176,7 @@ include("problems/docs.jl") include("systems/codegen.jl") include("systems/problem_utils.jl") include("linearization.jl") +include("systems/solver_nlprob.jl") include("problems/compatibility.jl") include("problems/odeproblem.jl") diff --git a/src/problems/odeproblem.jl b/src/problems/odeproblem.jl index da33963be6..68a38c95cf 100644 --- a/src/problems/odeproblem.jl +++ b/src/problems/odeproblem.jl @@ -3,7 +3,7 @@ t = nothing, eval_expression = false, eval_module = @__MODULE__, sparse = false, steady_state = false, checkbounds = false, sparsity = false, analytic = nothing, simplify = false, cse = true, initialization_data = nothing, expression = Val{false}, - check_compatibility = true, kwargs...) where {iip, spec} + check_compatibility = true, nlstep = false, kwargs...) where {iip, spec} check_complete(sys, ODEFunction) check_compatibility && check_compatible_system(ODEFunction, sys) @@ -41,6 +41,12 @@ M = calculate_massmatrix(sys) _M = concrete_massmatrix(M; sparse, u0) + if nlstep + ode_nlstep = generate_ODENLStepData(sys, u0, p, M) + else + ode_nlstep = nothing + end + observedfun = ObservedFunctionCache( sys; expression, steady_state, eval_expression, eval_module, checkbounds, cse) @@ -57,7 +63,8 @@ observed = observedfun, sparsity = sparsity ? _W_sparsity : nothing, analytic = analytic, - initialization_data) + initialization_data, + nlstep_data = ode_nlstep) maybe_codegen_scimlfn(expression, ODEFunction{iip, spec}, args; kwargs...) end diff --git a/src/systems/solver_nlprob.jl b/src/systems/solver_nlprob.jl new file mode 100644 index 0000000000..006241bf44 --- /dev/null +++ b/src/systems/solver_nlprob.jl @@ -0,0 +1,60 @@ +function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sys)) + nlsys, outer_tmp, inner_tmp = inner_nlsystem(sys, mm) + state = ProblemState(; u = u0, p) + op = Dict() + op[ODE_GAMMA[1]] = one(eltype(u0)) + op[ODE_GAMMA[2]] = one(eltype(u0)) + op[ODE_GAMMA[3]] = one(eltype(u0)) + op[ODE_C] = zero(eltype(u0)) + op[outer_tmp] = zeros(eltype(u0), size(outer_tmp)) + op[inner_tmp] = zeros(eltype(u0), size(inner_tmp)) + for v in [unknowns(nlsys); parameters(nlsys)] + haskey(op, v) && continue + op[v] = getsym(sys, v)(state) + end + nlprob = NonlinearProblem(nlsys, op; build_initializeprob = false) + + subsetidxs = [findfirst(isequal(y),unknowns(sys)) for y in unknowns(nlsys)] + set_gamma_c = setsym(nlsys, (ODE_GAMMA..., ODE_C)) + set_outer_tmp = setsym(nlsys, outer_tmp) + set_inner_tmp = setsym(nlsys, inner_tmp) + nlprobmap = getsym(nlsys, unknowns(sys)) + + return SciMLBase.ODENLStepData(nlprob, subsetidxs, set_gamma_c, set_outer_tmp, set_inner_tmp, nlprobmap) +end + +const ODE_GAMMA = @parameters γ₁ₘₜₖ, γ₂ₘₜₖ, γ₃ₘₜₖ +const ODE_C = only(@parameters cₘₜₖ) + +function get_outer_tmp(n::Int) + only(@parameters outer_tmpₘₜₖ[1:n]) +end + +function get_inner_tmp(n::Int) + only(@parameters inner_tmpₘₜₖ[1:n]) +end + +function inner_nlsystem(sys::System, mm) + dvs = unknowns(sys) + eqs = full_equations(sys) + t = get_iv(sys) + N = length(dvs) + @assert length(eqs) == N + @assert mm == I || size(mm) == (N, N) + rhss = [eq.rhs for eq in eqs] + gamma1, gamma2, gamma3 = ODE_GAMMA + c = ODE_C + outer_tmp = get_outer_tmp(N) + inner_tmp = get_inner_tmp(N) + + subrules = Dict([v => gamma2*v + inner_tmp[i] for (i, v) in enumerate(dvs)]) + subrules[t] = c + new_rhss = map(Base.Fix2(fast_substitute, subrules), rhss) + new_rhss = collect(outer_tmp) .+ gamma1 .* new_rhss .- gamma3 * mm * dvs + new_eqs = [0 ~ rhs for rhs in new_rhss] + + new_dvs = unknowns(sys) + new_ps = [parameters(sys); [gamma1, gamma2, gamma3, c, inner_tmp, outer_tmp]] + nlsys = mtkcompile(System(new_eqs, new_dvs, new_ps; name = :nlsys); split = is_split(sys)) + return nlsys, outer_tmp, inner_tmp +end