@@ -2,7 +2,9 @@ function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sy
2
2
nlsys, outer_tmp, inner_tmp = inner_nlsystem (sys, mm)
3
3
state = ProblemState (; u = u0, p)
4
4
op = Dict ()
5
- op[ODE_GAMMA] = one (eltype (u0))
5
+ op[ODE_GAMMA[1 ]] = one (eltype (u0))
6
+ op[ODE_GAMMA[2 ]] = one (eltype (u0))
7
+ op[ODE_GAMMA[3 ]] = one (eltype (u0))
6
8
op[ODE_C] = zero (eltype (u0))
7
9
op[outer_tmp] = zeros (eltype (u0), size (outer_tmp))
8
10
op[inner_tmp] = zeros (eltype (u0), size (inner_tmp))
@@ -11,15 +13,17 @@ function generate_ODENLStepData(sys::System, u0, p, mm = calculate_massmatrix(sy
11
13
op[v] = getsym (sys, v)(state)
12
14
end
13
15
nlprob = NonlinearProblem (nlsys, op; build_initializeprob = false )
16
+
17
+ subsetidxs = [findfirst (isequal (y),unknowns (sys)) for y in unknowns (nlsys)]
14
18
set_gamma_c = setsym (nlsys, (ODE_GAMMA... , ODE_C))
15
19
set_outer_tmp = setsym (nlsys, outer_tmp)
16
20
set_inner_tmp = setsym (nlsys, inner_tmp)
17
21
nlprobmap = getsym (nlsys, unknowns (sys))
18
22
19
- return SciMLBase. ODENLStepData (nlprob, set_gamma_c, set_outer_tmp, set_inner_tmp, nlprobmap)
23
+ return SciMLBase. ODENLStepData (nlprob, subsetidxs, set_gamma_c, set_outer_tmp, set_inner_tmp, nlprobmap)
20
24
end
21
25
22
- const ODE_GAMMA = @parameters γ₁ₘₜₖ, γ₂ₘₜₖ
26
+ const ODE_GAMMA = @parameters γ₁ₘₜₖ, γ₂ₘₜₖ, γ₃ₘₜₖ
23
27
const ODE_C = only (@parameters cₘₜₖ)
24
28
25
29
function get_outer_tmp (n:: Int )
@@ -38,19 +42,19 @@ function inner_nlsystem(sys::System, mm)
38
42
@assert length (eqs) == N
39
43
@assert mm == I || size (mm) == (N, N)
40
44
rhss = [eq. rhs for eq in eqs]
41
- gamma1, gamma2 = ODE_GAMMA
45
+ gamma1, gamma2, gamma3 = ODE_GAMMA
42
46
c = ODE_C
43
47
outer_tmp = get_outer_tmp (N)
44
48
inner_tmp = get_inner_tmp (N)
45
49
46
50
subrules = Dict ([v => gamma2* v + inner_tmp[i] for (i, v) in enumerate (dvs)])
47
51
subrules[t] = t + c
48
52
new_rhss = map (Base. Fix2 (fast_substitute, subrules), rhss)
49
- new_rhss = mm * dvs - gamma1 .* new_rhss .+ collect (outer_tmp)
53
+ new_rhss = collect (outer_tmp) .+ gamma1 .* new_rhss .- gamma3 * mm * dvs
50
54
new_eqs = [0 ~ rhs for rhs in new_rhss]
51
55
52
56
new_dvs = unknowns (sys)
53
- new_ps = [parameters (sys); [gamma1, gamma2, c, inner_tmp, outer_tmp]]
57
+ new_ps = [parameters (sys); [gamma1, gamma2, gamma3, c, inner_tmp, outer_tmp]]
54
58
nlsys = mtkcompile (System (new_eqs, new_dvs, new_ps; name = :nlsys ); split = is_split (sys))
55
59
return nlsys, outer_tmp, inner_tmp
56
60
end
0 commit comments