Skip to content

Commit 4c5e6d9

Browse files
expand gamma
1 parent 94aa324 commit 4c5e6d9

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

src/systems/solver_nlprob.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sy
22
nlsys, outer_tmp, inner_tmp = inner_nlsystem(sys, mm)
33
state = ProblemState(; u = u0, p)
44
op = Dict()
5-
op[ODE_GAMMA] = one(eltype(u0))
5+
op[ODE_GAMMA[1]] = one(eltype(u0))
6+
op[ODE_GAMMA[2]] = one(eltype(u0))
7+
op[ODE_GAMMA[3]] = one(eltype(u0))
68
op[ODE_C] = zero(eltype(u0))
79
op[outer_tmp] = zeros(eltype(u0), size(outer_tmp))
810
op[inner_tmp] = zeros(eltype(u0), size(inner_tmp))
@@ -11,15 +13,17 @@ function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sy
1113
op[v] = getsym(sys, v)(state)
1214
end
1315
nlprob = NonlinearProblem(nlsys, op; build_initializeprob = false)
16+
17+
subsetidxs = [findfirst(isequal(y),unknowns(sys)) for y in unknowns(nlsys)]
1418
set_gamma_c = setsym(nlsys, (ODE_GAMMA..., ODE_C))
1519
set_outer_tmp = setsym(nlsys, outer_tmp)
1620
set_inner_tmp = setsym(nlsys, inner_tmp)
1721
nlprobmap = getsym(nlsys, unknowns(sys))
1822

19-
return SciMLBase.ODENLStepData(nlprob, set_gamma_c, set_outer_tmp, set_inner_tmp, nlprobmap)
23+
return SciMLBase.ODENLStepData(nlprob, subsetidxs, set_gamma_c, set_outer_tmp, set_inner_tmp, nlprobmap)
2024
end
2125

22-
const ODE_GAMMA = @parameters γ₁ₘₜₖ, γ₂ₘₜₖ
26+
const ODE_GAMMA = @parameters γ₁ₘₜₖ, γ₂ₘₜₖ, γ₃ₘₜₖ
2327
const ODE_C = only(@parameters cₘₜₖ)
2428

2529
function get_outer_tmp(n::Int)
@@ -38,19 +42,19 @@ function inner_nlsystem(sys::System, mm)
3842
@assert length(eqs) == N
3943
@assert mm == I || size(mm) == (N, N)
4044
rhss = [eq.rhs for eq in eqs]
41-
gamma1, gamma2 = ODE_GAMMA
45+
gamma1, gamma2, gamma3 = ODE_GAMMA
4246
c = ODE_C
4347
outer_tmp = get_outer_tmp(N)
4448
inner_tmp = get_inner_tmp(N)
4549

4650
subrules = Dict([v => gamma2*v + inner_tmp[i] for (i, v) in enumerate(dvs)])
4751
subrules[t] = t + c
4852
new_rhss = map(Base.Fix2(fast_substitute, subrules), rhss)
49-
new_rhss = mm * dvs - gamma1 .* new_rhss .+ collect(outer_tmp)
53+
new_rhss = collect(outer_tmp) .+ gamma1 .* new_rhss .- gamma3 * mm * dvs
5054
new_eqs = [0 ~ rhs for rhs in new_rhss]
5155

5256
new_dvs = unknowns(sys)
53-
new_ps = [parameters(sys); [gamma1, gamma2, c, inner_tmp, outer_tmp]]
57+
new_ps = [parameters(sys); [gamma1, gamma2, gamma3, c, inner_tmp, outer_tmp]]
5458
nlsys = mtkcompile(System(new_eqs, new_dvs, new_ps; name = :nlsys); split = is_split(sys))
5559
return nlsys, outer_tmp, inner_tmp
5660
end

0 commit comments

Comments
 (0)