@@ -100,9 +100,159 @@ end
100
100
maybe_codegen_scimlproblem (expression, SteadyStateProblem{iip}, args; kwargs... )
101
101
end
102
102
103
+ @fallback_iip_specialize function SemilinearODEFunction {iip, specialize} (
104
+ sys:: System ; u0 = nothing , p = nothing , t = nothing ,
105
+ semiquadratic_form = nothing ,
106
+ stiff_linear = true , stiff_quadratic = false , stiff_nonlinear = false ,
107
+ eval_expression = false , eval_module = @__MODULE__ ,
108
+ expression = Val{false }, sparse = false , check_compatibility = true ,
109
+ jac = false , checkbounds = false , cse = true , initialization_data = nothing ,
110
+ analytic = nothing , kwargs... ) where {iip, specialize}
111
+ check_complete (sys, SemilinearODEFunction)
112
+ check_compatibility && check_compatible_system (SemilinearODEFunction, sys)
113
+
114
+ if semiquadratic_form === nothing
115
+ semiquadratic_form = calculate_semiquadratic_form (sys; sparse)
116
+ sys = add_semiquadratic_parameters (sys, semiquadratic_form... )
117
+ end
118
+
119
+ A, B, C = semiquadratic_form
120
+ M = calculate_massmatrix (sys)
121
+ _M = concrete_massmatrix (M; sparse, u0)
122
+ dvs = unknowns (sys)
123
+
124
+ f1, f2 = generate_semiquadratic_functions (
125
+ sys, A, B, C; stiff_linear, stiff_quadratic, stiff_nonlinear, expression, wrap_gfw = Val{true },
126
+ eval_expression, eval_module, kwargs... )
127
+
128
+ if jac
129
+ Cjac = (C === nothing || ! stiff_nonlinear) ? nothing : Symbolics. jacobian (C, dvs)
130
+ _jac = generate_semiquadratic_jacobian (
131
+ sys, A, B, C, Cjac; sparse, expression,
132
+ wrap_gfw = Val{true }, eval_expression, eval_module, kwargs... )
133
+ _W_sparsity = get_semiquadratic_W_sparsity (
134
+ sys, A, B, C, Cjac; stiff_linear, stiff_quadratic, stiff_nonlinear, mm = M)
135
+ W_prototype = calculate_W_prototype (_W_sparsity; u0, sparse)
136
+ else
137
+ _jac = nothing
138
+ W_prototype = nothing
139
+ end
140
+
141
+ observedfun = ObservedFunctionCache (
142
+ sys; expression, steady_state = false , eval_expression, eval_module, checkbounds, cse)
143
+
144
+ args = (; f1)
145
+ kwargs = (; jac = _jac, jac_prototype = W_prototype)
146
+ f1 = maybe_codegen_scimlfn (expression, ODEFunction{iip, specialize}, args; kwargs... )
147
+
148
+ args = (; f1, f2)
149
+ kwargs = (;
150
+ sys = sys,
151
+ jac = _jac,
152
+ mass_matrix = _M,
153
+ jac_prototype = W_prototype,
154
+ observed = observedfun,
155
+ analytic,
156
+ initialization_data)
157
+
158
+ return maybe_codegen_scimlfn (
159
+ expression, SplitFunction{iip, specialize}, args; kwargs... )
160
+ end
161
+
162
+ @fallback_iip_specialize function SemilinearODEProblem {iip, spec} (
163
+ sys:: System , op, tspan; check_compatibility = true , u0_eltype = nothing ,
164
+ expression = Val{false }, callback = nothing , sparse = false ,
165
+ stiff_linear = true , stiff_quadratic = false , stiff_nonlinear = false , jac = false , kwargs... ) where {
166
+ iip, spec}
167
+ check_complete (sys, SemilinearODEProblem)
168
+ check_compatibility && check_compatible_system (SemilinearODEProblem, sys)
169
+
170
+ A, B, C = semiquadratic_form = calculate_semiquadratic_form (sys; sparse)
171
+ eqs = equations (sys)
172
+ dvs = unknowns (sys)
173
+
174
+ sys = add_semiquadratic_parameters (sys, A, B, C)
175
+ if A != = nothing
176
+ linear_matrix_param = unwrap (getproperty (sys, LINEAR_MATRIX_PARAM_NAME))
177
+ else
178
+ linear_matrix_param = nothing
179
+ end
180
+ if B != = nothing
181
+ quadratic_forms = [unwrap (getproperty (sys, get_quadratic_form_name (i)))
182
+ for i in 1 : length (eqs)]
183
+ diffcache_par = unwrap (getproperty (sys, DIFFCACHE_PARAM_NAME))
184
+ else
185
+ quadratic_forms = diffcache_par = nothing
186
+ end
187
+
188
+ op = to_varmap (op, dvs)
189
+ floatT = calculate_float_type (op, typeof (op))
190
+ _u0_eltype = something (u0_eltype, floatT)
191
+
192
+ guess = copy (guesses (sys))
193
+ defs = copy (defaults (sys))
194
+ if A != = nothing
195
+ guess[linear_matrix_param] = fill (NaN , size (A))
196
+ defs[linear_matrix_param] = A
197
+ end
198
+ if B != = nothing
199
+ for (par, mat) in zip (quadratic_forms, B)
200
+ guess[par] = fill (NaN , size (mat))
201
+ defs[par] = mat
202
+ end
203
+ cachelen = jac ? length (dvs) * length (eqs) : length (dvs)
204
+ defs[diffcache_par] = DiffCache (zeros (DiffEqBase. value (_u0_eltype), cachelen))
205
+ end
206
+ @set! sys. guesses = guess
207
+ @set! sys. defaults = defs
208
+
209
+ f, u0, p = process_SciMLProblem (SemilinearODEFunction{iip, spec}, sys, op;
210
+ t = tspan != = nothing ? tspan[1 ] : tspan, expression, check_compatibility,
211
+ semiquadratic_form, sparse, u0_eltype, stiff_linear, stiff_quadratic, stiff_nonlinear, jac, kwargs... )
212
+
213
+ kwargs = process_kwargs (sys; expression, callback, kwargs... )
214
+
215
+ args = (; f, u0, tspan, p)
216
+ maybe_codegen_scimlproblem (expression, SplitODEProblem{iip}, args; kwargs... )
217
+ end
218
+
219
+ """
220
+ $(TYPEDSIGNATURES)
221
+
222
+ Add the necessary parameters for [`SemilinearODEProblem`](@ref) given the matrices
223
+ `A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref).
224
+ """
225
+ function add_semiquadratic_parameters (sys:: System , A, B, C)
226
+ eqs = equations (sys)
227
+ n = length (eqs)
228
+ var_to_name = copy (get_var_to_name (sys))
229
+ if B != = nothing
230
+ for i in eachindex (B)
231
+ B[i] === nothing && continue
232
+ par = get_quadratic_form_param ((n, n), i)
233
+ var_to_name[get_quadratic_form_name (i)] = par
234
+ sys = with_additional_constant_parameter (sys, par)
235
+ end
236
+ par = get_diffcache_param (Float64)
237
+ var_to_name[DIFFCACHE_PARAM_NAME] = par
238
+ sys = with_additional_nonnumeric_parameter (sys, par)
239
+ end
240
+ if A != = nothing
241
+ par = get_linear_matrix_param ((n, n))
242
+ var_to_name[LINEAR_MATRIX_PARAM_NAME] = par
243
+ sys = with_additional_constant_parameter (sys, par)
244
+ end
245
+ @set! sys. var_to_name = var_to_name
246
+ if get_parent (sys) != = nothing
247
+ @set! sys. parent = add_semiquadratic_parameters (get_parent (sys), A, B, C)
248
+ end
249
+ return sys
250
+ end
251
+
103
252
function check_compatible_system (
104
253
T:: Union {Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
105
- Type{DAEProblem}, Type{SteadyStateProblem}},
254
+ Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction},
255
+ Type{SemilinearODEProblem}},
106
256
sys:: System )
107
257
check_time_dependent (sys, T)
108
258
check_not_dde (sys)
0 commit comments