|
98 | 98 | maybe_codegen_scimlproblem(expression, SteadyStateProblem{iip}, args; kwargs...)
|
99 | 99 | end
|
100 | 100 |
|
| 101 | +@fallback_iip_specialize function SemilinearODEFunction{iip, specialize}( |
| 102 | + sys::System; u0 = nothing, p = nothing, t = nothing, |
| 103 | + semiquadratic_form = nothing, |
| 104 | + stiff_A = true, stiff_B = false, stiff_C = false, |
| 105 | + eval_expression = false, eval_module = @__MODULE__, |
| 106 | + expression = Val{false}, sparse = false, check_compatibility = true, |
| 107 | + jac = false, checkbounds = false, cse = true, initialization_data = nothing, |
| 108 | + analytic = nothing, kwargs...) where {iip, specialize} |
| 109 | + check_complete(sys, SemilinearODEFunction) |
| 110 | + check_compatibility && check_compatible_system(SemilinearODEFunction, sys) |
| 111 | + |
| 112 | + if semiquadratic_form === nothing |
| 113 | + semiquadratic_form = calculate_semiquadratic_form(sys; sparse) |
| 114 | + sys = add_semiquadratic_parameters(sys, semiquadratic_form...) |
| 115 | + end |
| 116 | + |
| 117 | + A, B, C = semiquadratic_form |
| 118 | + M = calculate_massmatrix(sys) |
| 119 | + _M = concrete_massmatrix(M; sparse, u0) |
| 120 | + dvs = unknowns(sys) |
| 121 | + |
| 122 | + f1, f2 = generate_semiquadratic_functions( |
| 123 | + sys, A, B, C; stiff_A, stiff_B, stiff_C, expression, wrap_gfw = Val{true}, |
| 124 | + eval_expression, eval_module, kwargs...) |
| 125 | + |
| 126 | + if jac |
| 127 | + Cjac = (C === nothing || !stiff_C) ? nothing : Symbolics.jacobian(C, dvs) |
| 128 | + _jac = generate_semiquadratic_jacobian( |
| 129 | + sys, A, B, C, Cjac; sparse, expression, |
| 130 | + wrap_gfw = Val{true}, eval_expression, eval_module, kwargs...) |
| 131 | + _W_sparsity = get_semiquadratic_W_sparsity( |
| 132 | + sys, A, B, C, Cjac; stiff_A, stiff_B, stiff_C, mm = M) |
| 133 | + W_prototype = calculate_W_prototype(_W_sparsity; u0, sparse) |
| 134 | + else |
| 135 | + _jac = nothing |
| 136 | + W_prototype = nothing |
| 137 | + end |
| 138 | + |
| 139 | + observedfun = ObservedFunctionCache( |
| 140 | + sys; expression, steady_state = false, eval_expression, eval_module, checkbounds, cse) |
| 141 | + |
| 142 | + args = (; f1) |
| 143 | + kwargs = (; jac = _jac, jac_prototype = W_prototype) |
| 144 | + f1 = maybe_codegen_scimlfn(expression, ODEFunction{iip, specialize}, args; kwargs...) |
| 145 | + |
| 146 | + args = (; f1, f2) |
| 147 | + kwargs = (; |
| 148 | + sys = sys, |
| 149 | + jac = _jac, |
| 150 | + mass_matrix = _M, |
| 151 | + jac_prototype = W_prototype, |
| 152 | + observed = observedfun, |
| 153 | + analytic, |
| 154 | + initialization_data) |
| 155 | + |
| 156 | + return maybe_codegen_scimlfn( |
| 157 | + expression, SplitFunction{iip, specialize}, args; kwargs...) |
| 158 | +end |
| 159 | + |
| 160 | +@fallback_iip_specialize function SemilinearODEProblem{iip, spec}( |
| 161 | + sys::System, op, tspan; check_compatibility = true, u0_eltype = nothing, |
| 162 | + expression = Val{false}, callback = nothing, sparse = false, |
| 163 | + stiff_A = true, stiff_B = false, stiff_C = false, jac = false, kwargs...) where { |
| 164 | + iip, spec} |
| 165 | + check_complete(sys, SemilinearODEProblem) |
| 166 | + check_compatibility && check_compatible_system(SemilinearODEProblem, sys) |
| 167 | + |
| 168 | + A, B, C = semiquadratic_form = calculate_semiquadratic_form(sys; sparse) |
| 169 | + eqs = equations(sys) |
| 170 | + dvs = unknowns(sys) |
| 171 | + |
| 172 | + sys = add_semiquadratic_parameters(sys, A, B, C) |
| 173 | + if A !== nothing |
| 174 | + linear_matrix_param = unwrap(getproperty(sys, LINEAR_MATRIX_PARAM_NAME)) |
| 175 | + else |
| 176 | + linear_matrix_param = nothing |
| 177 | + end |
| 178 | + if B !== nothing |
| 179 | + quadratic_forms = [unwrap(getproperty(sys, get_quadratic_form_name(i))) |
| 180 | + for i in 1:length(eqs)] |
| 181 | + diffcache_par = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME)) |
| 182 | + else |
| 183 | + quadratic_forms = diffcache_par = nothing |
| 184 | + end |
| 185 | + |
| 186 | + op = to_varmap(op, dvs) |
| 187 | + floatT = calculate_float_type(op, typeof(op)) |
| 188 | + _u0_eltype = something(u0_eltype, floatT) |
| 189 | + |
| 190 | + guess = copy(guesses(sys)) |
| 191 | + defs = copy(defaults(sys)) |
| 192 | + if A !== nothing |
| 193 | + guess[linear_matrix_param] = fill(NaN, size(A)) |
| 194 | + defs[linear_matrix_param] = A |
| 195 | + end |
| 196 | + if B !== nothing |
| 197 | + for (par, mat) in zip(quadratic_forms, B) |
| 198 | + guess[par] = fill(NaN, size(mat)) |
| 199 | + defs[par] = mat |
| 200 | + end |
| 201 | + cachelen = jac ? length(dvs) * length(eqs) : length(dvs) |
| 202 | + defs[diffcache_par] = DiffCache(zeros(DiffEqBase.value(_u0_eltype), cachelen)) |
| 203 | + end |
| 204 | + @set! sys.guesses = guess |
| 205 | + @set! sys.defaults = defs |
| 206 | + |
| 207 | + f, u0, p = process_SciMLProblem(SemilinearODEFunction{iip, spec}, sys, op; |
| 208 | + t = tspan !== nothing ? tspan[1] : tspan, expression, check_compatibility, |
| 209 | + semiquadratic_form, sparse, u0_eltype, stiff_A, stiff_B, stiff_C, jac, kwargs...) |
| 210 | + |
| 211 | + kwargs = process_kwargs(sys; expression, callback, kwargs...) |
| 212 | + |
| 213 | + args = (; f, u0, tspan, p) |
| 214 | + maybe_codegen_scimlproblem(expression, SplitODEProblem{iip}, args; kwargs...) |
| 215 | +end |
| 216 | + |
| 217 | +""" |
| 218 | + $(TYPEDSIGNATURES) |
| 219 | +
|
| 220 | +Add the necessary parameters for [`SemilinearODEProblem`](@ref) given the matrices |
| 221 | +`A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref). |
| 222 | +""" |
| 223 | +function add_semiquadratic_parameters(sys::System, A, B, C) |
| 224 | + eqs = equations(sys) |
| 225 | + n = length(eqs) |
| 226 | + var_to_name = copy(get_var_to_name(sys)) |
| 227 | + if B !== nothing |
| 228 | + for i in eachindex(B) |
| 229 | + B[i] === nothing && continue |
| 230 | + par = get_quadratic_form_param((n, n), i) |
| 231 | + var_to_name[get_quadratic_form_name(i)] = par |
| 232 | + sys = with_additional_constant_parameter(sys, par) |
| 233 | + end |
| 234 | + par = get_diffcache_param(Float64) |
| 235 | + var_to_name[DIFFCACHE_PARAM_NAME] = par |
| 236 | + sys = with_additional_nonnumeric_parameter(sys, par) |
| 237 | + end |
| 238 | + if A !== nothing |
| 239 | + par = get_linear_matrix_param((n, n)) |
| 240 | + var_to_name[LINEAR_MATRIX_PARAM_NAME] = par |
| 241 | + sys = with_additional_constant_parameter(sys, par) |
| 242 | + end |
| 243 | + @set! sys.var_to_name = var_to_name |
| 244 | + if get_parent(sys) !== nothing |
| 245 | + @set! sys.parent = add_semiquadratic_parameters(get_parent(sys), A, B, C) |
| 246 | + end |
| 247 | + return sys |
| 248 | +end |
| 249 | + |
101 | 250 | function check_compatible_system(
|
102 | 251 | T::Union{Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
|
103 |
| - Type{DAEProblem}, Type{SteadyStateProblem}}, |
| 252 | + Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction}, |
| 253 | + Type{SemilinearODEProblem}}, |
104 | 254 | sys::System)
|
105 | 255 | check_time_dependent(sys, T)
|
106 | 256 | check_not_dde(sys)
|
|
0 commit comments