Skip to content

Commit 0e578fe

Browse files
Merge pull request #3827 from AayushSabharwal/as/ode-nlprob
feat: generate `SciMLBase.ODE_NLProbData`
2 parents 2033f28 + 9db32a9 commit 0e578fe

File tree

3 files changed

+70
-2
lines changed

3 files changed

+70
-2
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ include("problems/docs.jl")
176176
include("systems/codegen.jl")
177177
include("systems/problem_utils.jl")
178178
include("linearization.jl")
179+
include("systems/solver_nlprob.jl")
179180

180181
include("problems/compatibility.jl")
181182
include("problems/odeproblem.jl")

src/problems/odeproblem.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
t = nothing, eval_expression = false, eval_module = @__MODULE__, sparse = false,
44
steady_state = false, checkbounds = false, sparsity = false, analytic = nothing,
55
simplify = false, cse = true, initialization_data = nothing, expression = Val{false},
6-
check_compatibility = true, kwargs...) where {iip, spec}
6+
check_compatibility = true, nlstep = false, kwargs...) where {iip, spec}
77
check_complete(sys, ODEFunction)
88
check_compatibility && check_compatible_system(ODEFunction, sys)
99

@@ -41,6 +41,12 @@
4141
M = calculate_massmatrix(sys)
4242
_M = concrete_massmatrix(M; sparse, u0)
4343

44+
if nlstep
45+
ode_nlstep = generate_ODENLStepData(sys, u0, p, M)
46+
else
47+
ode_nlstep = nothing
48+
end
49+
4450
observedfun = ObservedFunctionCache(
4551
sys; expression, steady_state, eval_expression, eval_module, checkbounds, cse)
4652

@@ -57,7 +63,8 @@
5763
observed = observedfun,
5864
sparsity = sparsity ? _W_sparsity : nothing,
5965
analytic = analytic,
60-
initialization_data)
66+
initialization_data,
67+
nlstep_data = ode_nlstep)
6168

6269
maybe_codegen_scimlfn(expression, ODEFunction{iip, spec}, args; kwargs...)
6370
end

src/systems/solver_nlprob.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sys))
2+
nlsys, outer_tmp, inner_tmp = inner_nlsystem(sys, mm)
3+
state = ProblemState(; u = u0, p)
4+
op = Dict()
5+
op[ODE_GAMMA[1]] = one(eltype(u0))
6+
op[ODE_GAMMA[2]] = one(eltype(u0))
7+
op[ODE_GAMMA[3]] = one(eltype(u0))
8+
op[ODE_C] = zero(eltype(u0))
9+
op[outer_tmp] = zeros(eltype(u0), size(outer_tmp))
10+
op[inner_tmp] = zeros(eltype(u0), size(inner_tmp))
11+
for v in [unknowns(nlsys); parameters(nlsys)]
12+
haskey(op, v) && continue
13+
op[v] = getsym(sys, v)(state)
14+
end
15+
nlprob = NonlinearProblem(nlsys, op; build_initializeprob = false)
16+
17+
subsetidxs = [findfirst(isequal(y),unknowns(sys)) for y in unknowns(nlsys)]
18+
set_gamma_c = setsym(nlsys, (ODE_GAMMA..., ODE_C))
19+
set_outer_tmp = setsym(nlsys, outer_tmp)
20+
set_inner_tmp = setsym(nlsys, inner_tmp)
21+
nlprobmap = getsym(nlsys, unknowns(sys))
22+
23+
return SciMLBase.ODENLStepData(nlprob, subsetidxs, set_gamma_c, set_outer_tmp, set_inner_tmp, nlprobmap)
24+
end
25+
26+
const ODE_GAMMA = @parameters γ₁ₘₜₖ, γ₂ₘₜₖ, γ₃ₘₜₖ
27+
const ODE_C = only(@parameters cₘₜₖ)
28+
29+
function get_outer_tmp(n::Int)
30+
only(@parameters outer_tmpₘₜₖ[1:n])
31+
end
32+
33+
function get_inner_tmp(n::Int)
34+
only(@parameters inner_tmpₘₜₖ[1:n])
35+
end
36+
37+
function inner_nlsystem(sys::System, mm)
38+
dvs = unknowns(sys)
39+
eqs = full_equations(sys)
40+
t = get_iv(sys)
41+
N = length(dvs)
42+
@assert length(eqs) == N
43+
@assert mm == I || size(mm) == (N, N)
44+
rhss = [eq.rhs for eq in eqs]
45+
gamma1, gamma2, gamma3 = ODE_GAMMA
46+
c = ODE_C
47+
outer_tmp = get_outer_tmp(N)
48+
inner_tmp = get_inner_tmp(N)
49+
50+
subrules = Dict([v => gamma2*v + inner_tmp[i] for (i, v) in enumerate(dvs)])
51+
subrules[t] = c
52+
new_rhss = map(Base.Fix2(fast_substitute, subrules), rhss)
53+
new_rhss = collect(outer_tmp) .+ gamma1 .* new_rhss .- gamma3 * mm * dvs
54+
new_eqs = [0 ~ rhs for rhs in new_rhss]
55+
56+
new_dvs = unknowns(sys)
57+
new_ps = [parameters(sys); [gamma1, gamma2, gamma3, c, inner_tmp, outer_tmp]]
58+
nlsys = mtkcompile(System(new_eqs, new_dvs, new_ps; name = :nlsys); split = is_split(sys))
59+
return nlsys, outer_tmp, inner_tmp
60+
end

0 commit comments

Comments
 (0)