Skip to content

Commit be0a9f4

Browse files
feat: add SemilinearODEFunction and SemilinearODEProblem
1 parent 6af4474 commit be0a9f4

File tree

4 files changed

+630
-1
lines changed

4 files changed

+630
-1
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
2626
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
2727
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
2828
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
29+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
2930
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
3031
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3132
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
@@ -45,6 +46,7 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
4546
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
4647
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
4748
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
49+
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
4850
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
4951
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
5052
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -115,6 +117,7 @@ DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
115117
EnumX = "1.0.4"
116118
ExprTools = "0.1.10"
117119
FMI = "0.14"
120+
FillArrays = "1.13.0"
118121
FindFirstFunctions = "1"
119122
ForwardDiff = "0.10.3"
120123
FunctionWrappers = "1.1"
@@ -142,6 +145,7 @@ OrdinaryDiffEq = "6.82.0"
142145
OrdinaryDiffEqCore = "1.15.0"
143146
OrdinaryDiffEqDefault = "1.2"
144147
OrdinaryDiffEqNonlinearSolve = "1.5.0"
148+
PreallocationTools = "0.4.27"
145149
PrecompileTools = "1"
146150
Pyomo = "0.1.0"
147151
REPL = "1"

src/ModelingToolkit.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ const DQ = DynamicQuantities
9999
import DifferentiationInterface as DI
100100
using ADTypes: AutoForwardDiff
101101
import SciMLPublic: @public
102+
import PreallocationTools
103+
import PreallocationTools: DiffCache
104+
import FillArrays
102105

103106
export @derivatives
104107

@@ -256,6 +259,7 @@ export IntervalNonlinearProblem
256259
export OptimizationProblem, constraints
257260
export SteadyStateProblem
258261
export JumpProblem
262+
export SemilinearODEFunction, SemilinearODEProblem
259263
export alias_elimination, flatten
260264
export connect, domain_connect, @connector, Connection, AnalysisPoint, Flow, Stream,
261265
instream

src/problems/odeproblem.jl

Lines changed: 151 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,159 @@ end
100100
maybe_codegen_scimlproblem(expression, SteadyStateProblem{iip}, args; kwargs...)
101101
end
102102

103+
@fallback_iip_specialize function SemilinearODEFunction{iip, specialize}(
104+
sys::System; u0 = nothing, p = nothing, t = nothing,
105+
semiquadratic_form = nothing,
106+
stiff_linear = true, stiff_quadratic = false, stiff_nonlinear = false,
107+
eval_expression = false, eval_module = @__MODULE__,
108+
expression = Val{false}, sparse = false, check_compatibility = true,
109+
jac = false, checkbounds = false, cse = true, initialization_data = nothing,
110+
analytic = nothing, kwargs...) where {iip, specialize}
111+
check_complete(sys, SemilinearODEFunction)
112+
check_compatibility && check_compatible_system(SemilinearODEFunction, sys)
113+
114+
if semiquadratic_form === nothing
115+
semiquadratic_form = calculate_semiquadratic_form(sys; sparse)
116+
sys = add_semiquadratic_parameters(sys, semiquadratic_form...)
117+
end
118+
119+
A, B, C = semiquadratic_form
120+
M = calculate_massmatrix(sys)
121+
_M = concrete_massmatrix(M; sparse, u0)
122+
dvs = unknowns(sys)
123+
124+
f1, f2 = generate_semiquadratic_functions(
125+
sys, A, B, C; stiff_linear, stiff_quadratic, stiff_nonlinear, expression, wrap_gfw = Val{true},
126+
eval_expression, eval_module, kwargs...)
127+
128+
if jac
129+
Cjac = (C === nothing || !stiff_nonlinear) ? nothing : Symbolics.jacobian(C, dvs)
130+
_jac = generate_semiquadratic_jacobian(
131+
sys, A, B, C, Cjac; sparse, expression,
132+
wrap_gfw = Val{true}, eval_expression, eval_module, kwargs...)
133+
_W_sparsity = get_semiquadratic_W_sparsity(
134+
sys, A, B, C, Cjac; stiff_linear, stiff_quadratic, stiff_nonlinear, mm = M)
135+
W_prototype = calculate_W_prototype(_W_sparsity; u0, sparse)
136+
else
137+
_jac = nothing
138+
W_prototype = nothing
139+
end
140+
141+
observedfun = ObservedFunctionCache(
142+
sys; expression, steady_state = false, eval_expression, eval_module, checkbounds, cse)
143+
144+
args = (; f1)
145+
kwargs = (; jac = _jac, jac_prototype = W_prototype)
146+
f1 = maybe_codegen_scimlfn(expression, ODEFunction{iip, specialize}, args; kwargs...)
147+
148+
args = (; f1, f2)
149+
kwargs = (;
150+
sys = sys,
151+
jac = _jac,
152+
mass_matrix = _M,
153+
jac_prototype = W_prototype,
154+
observed = observedfun,
155+
analytic,
156+
initialization_data)
157+
158+
return maybe_codegen_scimlfn(
159+
expression, SplitFunction{iip, specialize}, args; kwargs...)
160+
end
161+
162+
@fallback_iip_specialize function SemilinearODEProblem{iip, spec}(
163+
sys::System, op, tspan; check_compatibility = true, u0_eltype = nothing,
164+
expression = Val{false}, callback = nothing, sparse = false,
165+
stiff_linear = true, stiff_quadratic = false, stiff_nonlinear = false, jac = false, kwargs...) where {
166+
iip, spec}
167+
check_complete(sys, SemilinearODEProblem)
168+
check_compatibility && check_compatible_system(SemilinearODEProblem, sys)
169+
170+
A, B, C = semiquadratic_form = calculate_semiquadratic_form(sys; sparse)
171+
eqs = equations(sys)
172+
dvs = unknowns(sys)
173+
174+
sys = add_semiquadratic_parameters(sys, A, B, C)
175+
if A !== nothing
176+
linear_matrix_param = unwrap(getproperty(sys, LINEAR_MATRIX_PARAM_NAME))
177+
else
178+
linear_matrix_param = nothing
179+
end
180+
if B !== nothing
181+
quadratic_forms = [unwrap(getproperty(sys, get_quadratic_form_name(i)))
182+
for i in 1:length(eqs)]
183+
diffcache_par = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME))
184+
else
185+
quadratic_forms = diffcache_par = nothing
186+
end
187+
188+
op = to_varmap(op, dvs)
189+
floatT = calculate_float_type(op, typeof(op))
190+
_u0_eltype = something(u0_eltype, floatT)
191+
192+
guess = copy(guesses(sys))
193+
defs = copy(defaults(sys))
194+
if A !== nothing
195+
guess[linear_matrix_param] = fill(NaN, size(A))
196+
defs[linear_matrix_param] = A
197+
end
198+
if B !== nothing
199+
for (par, mat) in zip(quadratic_forms, B)
200+
guess[par] = fill(NaN, size(mat))
201+
defs[par] = mat
202+
end
203+
cachelen = jac ? length(dvs) * length(eqs) : length(dvs)
204+
defs[diffcache_par] = DiffCache(zeros(DiffEqBase.value(_u0_eltype), cachelen))
205+
end
206+
@set! sys.guesses = guess
207+
@set! sys.defaults = defs
208+
209+
f, u0, p = process_SciMLProblem(SemilinearODEFunction{iip, spec}, sys, op;
210+
t = tspan !== nothing ? tspan[1] : tspan, expression, check_compatibility,
211+
semiquadratic_form, sparse, u0_eltype, stiff_linear, stiff_quadratic, stiff_nonlinear, jac, kwargs...)
212+
213+
kwargs = process_kwargs(sys; expression, callback, kwargs...)
214+
215+
args = (; f, u0, tspan, p)
216+
maybe_codegen_scimlproblem(expression, SplitODEProblem{iip}, args; kwargs...)
217+
end
218+
219+
"""
220+
$(TYPEDSIGNATURES)
221+
222+
Add the necessary parameters for [`SemilinearODEProblem`](@ref) given the matrices
223+
`A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref).
224+
"""
225+
function add_semiquadratic_parameters(sys::System, A, B, C)
226+
eqs = equations(sys)
227+
n = length(eqs)
228+
var_to_name = copy(get_var_to_name(sys))
229+
if B !== nothing
230+
for i in eachindex(B)
231+
B[i] === nothing && continue
232+
par = get_quadratic_form_param((n, n), i)
233+
var_to_name[get_quadratic_form_name(i)] = par
234+
sys = with_additional_constant_parameter(sys, par)
235+
end
236+
par = get_diffcache_param(Float64)
237+
var_to_name[DIFFCACHE_PARAM_NAME] = par
238+
sys = with_additional_nonnumeric_parameter(sys, par)
239+
end
240+
if A !== nothing
241+
par = get_linear_matrix_param((n, n))
242+
var_to_name[LINEAR_MATRIX_PARAM_NAME] = par
243+
sys = with_additional_constant_parameter(sys, par)
244+
end
245+
@set! sys.var_to_name = var_to_name
246+
if get_parent(sys) !== nothing
247+
@set! sys.parent = add_semiquadratic_parameters(get_parent(sys), A, B, C)
248+
end
249+
return sys
250+
end
251+
103252
function check_compatible_system(
104253
T::Union{Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
105-
Type{DAEProblem}, Type{SteadyStateProblem}},
254+
Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction},
255+
Type{SemilinearODEProblem}},
106256
sys::System)
107257
check_time_dependent(sys, T)
108258
check_not_dde(sys)

0 commit comments

Comments
 (0)