|
| 1 | +function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sys)) |
| 2 | + nlsys, outer_tmp, inner_tmp = inner_nlsystem(sys, mm) |
| 3 | + state = ProblemState(; u = u0, p) |
| 4 | + op = Dict() |
| 5 | + op[ODE_GAMMA[1]] = one(eltype(u0)) |
| 6 | + op[ODE_GAMMA[2]] = one(eltype(u0)) |
| 7 | + op[ODE_GAMMA[3]] = one(eltype(u0)) |
| 8 | + op[ODE_C] = zero(eltype(u0)) |
| 9 | + op[outer_tmp] = zeros(eltype(u0), size(outer_tmp)) |
| 10 | + op[inner_tmp] = zeros(eltype(u0), size(inner_tmp)) |
| 11 | + for v in [unknowns(nlsys); parameters(nlsys)] |
| 12 | + haskey(op, v) && continue |
| 13 | + op[v] = getsym(sys, v)(state) |
| 14 | + end |
| 15 | + nlprob = NonlinearProblem(nlsys, op; build_initializeprob = false) |
| 16 | + |
| 17 | + subsetidxs = [findfirst(isequal(y),unknowns(sys)) for y in unknowns(nlsys)] |
| 18 | + set_gamma_c = setsym(nlsys, (ODE_GAMMA..., ODE_C)) |
| 19 | + set_outer_tmp = setsym(nlsys, outer_tmp) |
| 20 | + set_inner_tmp = setsym(nlsys, inner_tmp) |
| 21 | + nlprobmap = getsym(nlsys, unknowns(sys)) |
| 22 | + |
| 23 | + return SciMLBase.ODENLStepData(nlprob, subsetidxs, set_gamma_c, set_outer_tmp, set_inner_tmp, nlprobmap) |
| 24 | +end |
| 25 | + |
| 26 | +const ODE_GAMMA = @parameters γ₁ₘₜₖ, γ₂ₘₜₖ, γ₃ₘₜₖ |
| 27 | +const ODE_C = only(@parameters cₘₜₖ) |
| 28 | + |
| 29 | +function get_outer_tmp(n::Int) |
| 30 | + only(@parameters outer_tmpₘₜₖ[1:n]) |
| 31 | +end |
| 32 | + |
| 33 | +function get_inner_tmp(n::Int) |
| 34 | + only(@parameters inner_tmpₘₜₖ[1:n]) |
| 35 | +end |
| 36 | + |
| 37 | +function inner_nlsystem(sys::System, mm) |
| 38 | + dvs = unknowns(sys) |
| 39 | + eqs = full_equations(sys) |
| 40 | + t = get_iv(sys) |
| 41 | + N = length(dvs) |
| 42 | + @assert length(eqs) == N |
| 43 | + @assert mm == I || size(mm) == (N, N) |
| 44 | + rhss = [eq.rhs for eq in eqs] |
| 45 | + gamma1, gamma2, gamma3 = ODE_GAMMA |
| 46 | + c = ODE_C |
| 47 | + outer_tmp = get_outer_tmp(N) |
| 48 | + inner_tmp = get_inner_tmp(N) |
| 49 | + |
| 50 | + subrules = Dict([v => gamma2*v + inner_tmp[i] for (i, v) in enumerate(dvs)]) |
| 51 | + subrules[t] = c |
| 52 | + new_rhss = map(Base.Fix2(fast_substitute, subrules), rhss) |
| 53 | + new_rhss = collect(outer_tmp) .+ gamma1 .* new_rhss .- gamma3 * mm * dvs |
| 54 | + new_eqs = [0 ~ rhs for rhs in new_rhss] |
| 55 | + |
| 56 | + new_dvs = unknowns(sys) |
| 57 | + new_ps = [parameters(sys); [gamma1, gamma2, gamma3, c, inner_tmp, outer_tmp]] |
| 58 | + nlsys = mtkcompile(System(new_eqs, new_dvs, new_ps; name = :nlsys); split = is_split(sys)) |
| 59 | + return nlsys, outer_tmp, inner_tmp |
| 60 | +end |
0 commit comments