Skip to content

Commit 4546474

Browse files
feat: add SemilinearODEFunction and SemilinearODEProblem
1 parent cdc7071 commit 4546474

File tree

4 files changed

+630
-1
lines changed

4 files changed

+630
-1
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
2626
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
2727
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
2828
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
29+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
2930
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
3031
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3132
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
@@ -45,6 +46,7 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
4546
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
4647
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
4748
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
49+
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
4850
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
4951
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
5052
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -115,6 +117,7 @@ DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
115117
EnumX = "1.0.4"
116118
ExprTools = "0.1.10"
117119
FMI = "0.14"
120+
FillArrays = "1.13.0"
118121
FindFirstFunctions = "1"
119122
ForwardDiff = "0.10.3"
120123
FunctionWrappers = "1.1"
@@ -142,6 +145,7 @@ OrdinaryDiffEq = "6.82.0"
142145
OrdinaryDiffEqCore = "1.15.0"
143146
OrdinaryDiffEqDefault = "1.2"
144147
OrdinaryDiffEqNonlinearSolve = "1.5.0"
148+
PreallocationTools = "0.4.27"
145149
PrecompileTools = "1"
146150
Pyomo = "0.1.0"
147151
REPL = "1"

src/ModelingToolkit.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ const DQ = DynamicQuantities
9999
import DifferentiationInterface as DI
100100
using ADTypes: AutoForwardDiff
101101
import SciMLPublic: @public
102+
import PreallocationTools
103+
import PreallocationTools: DiffCache
104+
import FillArrays
102105

103106
export @derivatives
104107

@@ -288,6 +291,7 @@ export IntervalNonlinearProblem
288291
export OptimizationProblem, constraints
289292
export SteadyStateProblem
290293
export JumpProblem
294+
export SemilinearODEFunction, SemilinearODEProblem
291295
export alias_elimination, flatten
292296
export connect, domain_connect, @connector, Connection, AnalysisPoint, Flow, Stream,
293297
instream

src/problems/odeproblem.jl

Lines changed: 151 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,159 @@ end
9898
maybe_codegen_scimlproblem(expression, SteadyStateProblem{iip}, args; kwargs...)
9999
end
100100

101+
@fallback_iip_specialize function SemilinearODEFunction{iip, specialize}(
102+
sys::System; u0 = nothing, p = nothing, t = nothing,
103+
semiquadratic_form = nothing,
104+
stiff_A = true, stiff_B = false, stiff_C = false,
105+
eval_expression = false, eval_module = @__MODULE__,
106+
expression = Val{false}, sparse = false, check_compatibility = true,
107+
jac = false, checkbounds = false, cse = true, initialization_data = nothing,
108+
analytic = nothing, kwargs...) where {iip, specialize}
109+
check_complete(sys, SemilinearODEFunction)
110+
check_compatibility && check_compatible_system(SemilinearODEFunction, sys)
111+
112+
if semiquadratic_form === nothing
113+
semiquadratic_form = calculate_semiquadratic_form(sys; sparse)
114+
sys = add_semiquadratic_parameters(sys, semiquadratic_form...)
115+
end
116+
117+
A, B, C = semiquadratic_form
118+
M = calculate_massmatrix(sys)
119+
_M = concrete_massmatrix(M; sparse, u0)
120+
dvs = unknowns(sys)
121+
122+
f1, f2 = generate_semiquadratic_functions(
123+
sys, A, B, C; stiff_A, stiff_B, stiff_C, expression, wrap_gfw = Val{true},
124+
eval_expression, eval_module, kwargs...)
125+
126+
if jac
127+
Cjac = (C === nothing || !stiff_C) ? nothing : Symbolics.jacobian(C, dvs)
128+
_jac = generate_semiquadratic_jacobian(
129+
sys, A, B, C, Cjac; sparse, expression,
130+
wrap_gfw = Val{true}, eval_expression, eval_module, kwargs...)
131+
_W_sparsity = get_semiquadratic_W_sparsity(
132+
sys, A, B, C, Cjac; stiff_A, stiff_B, stiff_C, mm = M)
133+
W_prototype = calculate_W_prototype(_W_sparsity; u0, sparse)
134+
else
135+
_jac = nothing
136+
W_prototype = nothing
137+
end
138+
139+
observedfun = ObservedFunctionCache(
140+
sys; expression, steady_state = false, eval_expression, eval_module, checkbounds, cse)
141+
142+
args = (; f1)
143+
kwargs = (; jac = _jac, jac_prototype = W_prototype)
144+
f1 = maybe_codegen_scimlfn(expression, ODEFunction{iip, specialize}, args; kwargs...)
145+
146+
args = (; f1, f2)
147+
kwargs = (;
148+
sys = sys,
149+
jac = _jac,
150+
mass_matrix = _M,
151+
jac_prototype = W_prototype,
152+
observed = observedfun,
153+
analytic,
154+
initialization_data)
155+
156+
return maybe_codegen_scimlfn(
157+
expression, SplitFunction{iip, specialize}, args; kwargs...)
158+
end
159+
160+
@fallback_iip_specialize function SemilinearODEProblem{iip, spec}(
161+
sys::System, op, tspan; check_compatibility = true, u0_eltype = nothing,
162+
expression = Val{false}, callback = nothing, sparse = false,
163+
stiff_A = true, stiff_B = false, stiff_C = false, jac = false, kwargs...) where {
164+
iip, spec}
165+
check_complete(sys, SemilinearODEProblem)
166+
check_compatibility && check_compatible_system(SemilinearODEProblem, sys)
167+
168+
A, B, C = semiquadratic_form = calculate_semiquadratic_form(sys; sparse)
169+
eqs = equations(sys)
170+
dvs = unknowns(sys)
171+
172+
sys = add_semiquadratic_parameters(sys, A, B, C)
173+
if A !== nothing
174+
linear_matrix_param = unwrap(getproperty(sys, LINEAR_MATRIX_PARAM_NAME))
175+
else
176+
linear_matrix_param = nothing
177+
end
178+
if B !== nothing
179+
quadratic_forms = [unwrap(getproperty(sys, get_quadratic_form_name(i)))
180+
for i in 1:length(eqs)]
181+
diffcache_par = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME))
182+
else
183+
quadratic_forms = diffcache_par = nothing
184+
end
185+
186+
op = to_varmap(op, dvs)
187+
floatT = calculate_float_type(op, typeof(op))
188+
_u0_eltype = something(u0_eltype, floatT)
189+
190+
guess = copy(guesses(sys))
191+
defs = copy(defaults(sys))
192+
if A !== nothing
193+
guess[linear_matrix_param] = fill(NaN, size(A))
194+
defs[linear_matrix_param] = A
195+
end
196+
if B !== nothing
197+
for (par, mat) in zip(quadratic_forms, B)
198+
guess[par] = fill(NaN, size(mat))
199+
defs[par] = mat
200+
end
201+
cachelen = jac ? length(dvs) * length(eqs) : length(dvs)
202+
defs[diffcache_par] = DiffCache(zeros(DiffEqBase.value(_u0_eltype), cachelen))
203+
end
204+
@set! sys.guesses = guess
205+
@set! sys.defaults = defs
206+
207+
f, u0, p = process_SciMLProblem(SemilinearODEFunction{iip, spec}, sys, op;
208+
t = tspan !== nothing ? tspan[1] : tspan, expression, check_compatibility,
209+
semiquadratic_form, sparse, u0_eltype, stiff_A, stiff_B, stiff_C, jac, kwargs...)
210+
211+
kwargs = process_kwargs(sys; expression, callback, kwargs...)
212+
213+
args = (; f, u0, tspan, p)
214+
maybe_codegen_scimlproblem(expression, SplitODEProblem{iip}, args; kwargs...)
215+
end
216+
217+
"""
218+
$(TYPEDSIGNATURES)
219+
220+
Add the necessary parameters for [`SemilinearODEProblem`](@ref) given the matrices
221+
`A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref).
222+
"""
223+
function add_semiquadratic_parameters(sys::System, A, B, C)
224+
eqs = equations(sys)
225+
n = length(eqs)
226+
var_to_name = copy(get_var_to_name(sys))
227+
if B !== nothing
228+
for i in eachindex(B)
229+
B[i] === nothing && continue
230+
par = get_quadratic_form_param((n, n), i)
231+
var_to_name[get_quadratic_form_name(i)] = par
232+
sys = with_additional_constant_parameter(sys, par)
233+
end
234+
par = get_diffcache_param(Float64)
235+
var_to_name[DIFFCACHE_PARAM_NAME] = par
236+
sys = with_additional_nonnumeric_parameter(sys, par)
237+
end
238+
if A !== nothing
239+
par = get_linear_matrix_param((n, n))
240+
var_to_name[LINEAR_MATRIX_PARAM_NAME] = par
241+
sys = with_additional_constant_parameter(sys, par)
242+
end
243+
@set! sys.var_to_name = var_to_name
244+
if get_parent(sys) !== nothing
245+
@set! sys.parent = add_semiquadratic_parameters(get_parent(sys), A, B, C)
246+
end
247+
return sys
248+
end
249+
101250
function check_compatible_system(
102251
T::Union{Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
103-
Type{DAEProblem}, Type{SteadyStateProblem}},
252+
Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction},
253+
Type{SemilinearODEProblem}},
104254
sys::System)
105255
check_time_dependent(sys, T)
106256
check_not_dde(sys)

0 commit comments

Comments
 (0)