@@ -68,17 +68,26 @@ function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
68
68
" OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`." )
69
69
end
70
70
71
+ struct OverrideInitNoTolerance <: Exception
72
+ tolerance:: Symbol
73
+ end
74
+
75
+ function Base. showerror (io:: IO , e:: OverrideInitNoTolerance )
76
+ print (io,
77
+ " Tolerances were not provided to `OverrideInit`. `$(e. tolerance) ` must be provided as a keyword argument to `get_initial_values` or as a keyword argument to the `OverrideInit` constructor." )
78
+ end
79
+
71
80
"""
72
- Utility function to evaluate the RHS of the ODE , using the integrator's `tmp_cache` if
81
+ Utility function to evaluate the RHS, using the integrator's `tmp_cache` if
73
82
it is in-place or simply calling the function if not.
74
83
"""
75
- function _evaluate_f_ode (integrator, f, isinplace:: Val{true} , args... )
84
+ function _evaluate_f (integrator, f, isinplace:: Val{true} , args... )
76
85
tmp = first (get_tmp_cache (integrator))
77
86
f (tmp, args... )
78
87
return tmp
79
88
end
80
89
81
- function _evaluate_f_ode (integrator, f, isinplace:: Val{false} , args... )
90
+ function _evaluate_f (integrator, f, isinplace:: Val{false} , args... )
82
91
return f (args... )
83
92
end
84
93
@@ -98,53 +107,49 @@ _vec(v::AbstractVector) = v
98
107
99
108
Check if the algebraic constraints are satisfied, and error if they aren't. Returns
100
109
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
101
- `ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
110
+ `AbstractDEProblem` and `AbstractDAEProblem`. Requires a `DEIntegrator` as its second argument.
111
+
112
+ Keyword arguments:
113
+ - `abstol`: The absolute value below which the norm of the residual of algebraic equations
114
+ should lie. The norm function used is `integrator.opts.internalnorm` if present, and
115
+ `LinearAlgebra.norm` if not.
102
116
"""
103
- function get_initial_values (prob:: AbstractODEProblem , integrator, f, alg:: CheckInit ,
104
- isinplace:: Union{Val{true}, Val{false}} ; kwargs... )
117
+ function get_initial_values (
118
+ prob:: AbstractDEProblem , integrator:: DEIntegrator , f, alg:: CheckInit ,
119
+ isinplace:: Union{Val{true}, Val{false}} ; abstol, kwargs... )
105
120
u0 = state_values (integrator)
106
121
p = parameter_values (integrator)
107
122
t = current_time (integrator)
108
123
M = f. mass_matrix
109
124
110
125
algebraic_vars = [all (iszero, x) for x in eachcol (M)]
111
126
algebraic_eqs = [all (iszero, x) for x in eachrow (M)]
112
- (iszero (algebraic_vars) || iszero (algebraic_eqs)) && return
127
+ (iszero (algebraic_vars) || iszero (algebraic_eqs)) && return u0, p, true
113
128
update_coefficients! (M, u0, p, t)
114
- tmp = _evaluate_f_ode (integrator, f, isinplace, u0, p, t)
129
+ tmp = _evaluate_f (integrator, f, isinplace, u0, p, t)
115
130
tmp .= ArrayInterface. restructure (tmp, algebraic_eqs .* _vec (tmp))
116
131
117
- normresid = integrator. opts. internalnorm (tmp, t)
118
- if normresid > integrator. opts. abstol
119
- throw (CheckInitFailureError (normresid, integrator. opts. abstol))
132
+ normresid = isdefined (integrator. opts, :internalnorm ) ?
133
+ integrator. opts. internalnorm (tmp, t) : norm (tmp)
134
+ if normresid > abstol
135
+ throw (CheckInitFailureError (normresid, abstol))
120
136
end
121
137
return u0, p, true
122
138
end
123
139
124
- """
125
- Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if
126
- it is in-place or simply calling the function if not.
127
- """
128
- function _evaluate_f_dae (integrator, f, isinplace:: Val{true} , args... )
129
- tmp = get_tmp_cache (integrator)[2 ]
130
- f (tmp, args... )
131
- return tmp
132
- end
133
-
134
- function _evaluate_f_dae (integrator, f, isinplace:: Val{false} , args... )
135
- return f (args... )
136
- end
137
-
138
- function get_initial_values (prob:: AbstractDAEProblem , integrator, f, alg:: CheckInit ,
139
- isinplace:: Union{Val{true}, Val{false}} ; kwargs... )
140
+ function get_initial_values (
141
+ prob:: AbstractDAEProblem , integrator:: DEIntegrator , f, alg:: CheckInit ,
142
+ isinplace:: Union{Val{true}, Val{false}} ; abstol, kwargs... )
140
143
u0 = state_values (integrator)
141
144
p = parameter_values (integrator)
142
145
t = current_time (integrator)
143
146
144
- resid = _evaluate_f_dae (integrator, f, isinplace, integrator. du, u0, p, t)
145
- normresid = integrator. opts. internalnorm (resid, t)
146
- if normresid > integrator. opts. abstol
147
- throw (CheckInitFailureError (normresid, integrator. opts. abstol))
147
+ resid = _evaluate_f (integrator, f, isinplace, integrator. du, u0, p, t)
148
+ normresid = isdefined (integrator. opts, :internalnorm ) ?
149
+ integrator. opts. internalnorm (resid, t) : norm (resid)
150
+
151
+ if normresid > abstol
152
+ throw (CheckInitFailureError (normresid, abstol))
148
153
end
149
154
return u0, p, true
150
155
end
@@ -155,12 +160,19 @@ end
155
160
Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
156
161
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
157
162
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
158
- The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
159
- argument, failing which this function will throw an error. The success value returned
160
- depends on the success of the nonlinear solve.
163
+
164
+ The success value returned depends on the success of the nonlinear solve.
165
+
166
+ Keyword arguments:
167
+ - `nlsolve_alg`: The NonlinearSolve.jl algorithm to use. If not provided, this function will
168
+ throw an error.
169
+ - `abstol`, `reltol`: The `abstol` (`reltol`) to use for the nonlinear solve. The value
170
+ provided to the `OverrideInit` constructor takes priority over this keyword argument.
171
+ If the former is `nothing`, this keyword argument will be used. If it is also not provided,
172
+ an error will be thrown.
161
173
"""
162
174
function get_initial_values (prob, valp, f, alg:: OverrideInit ,
163
- isinplace :: Union{Val{true}, Val{false}} ; nlsolve_alg = nothing , kwargs... )
175
+ iip :: Union{Val{true}, Val{false}} ; nlsolve_alg = nothing , abstol = nothing , reltol = nothing , kwargs... )
164
176
u0 = state_values (valp)
165
177
p = parameter_values (valp)
166
178
@@ -171,15 +183,30 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
171
183
initdata:: OverrideInitData = f. initialization_data
172
184
initprob = initdata. initializeprob
173
185
174
- if nlsolve_alg === nothing
186
+ nlsolve_alg = something (nlsolve_alg, alg. nlsolve, Some (nothing ))
187
+ if nlsolve_alg === nothing && state_values (initprob) != = nothing
175
188
throw (OverrideInitMissingAlgorithm ())
176
189
end
177
190
178
191
if initdata. update_initializeprob! != = nothing
179
192
initdata. update_initializeprob! (initprob, valp)
180
193
end
181
194
182
- nlsol = solve (initprob, nlsolve_alg)
195
+ if alg. abstol != = nothing
196
+ _abstol = alg. abstol
197
+ elseif abstol != = nothing
198
+ _abstol = abstol
199
+ else
200
+ throw (OverrideInitNoTolerance (:abstol ))
201
+ end
202
+ if alg. reltol != = nothing
203
+ _reltol = alg. reltol
204
+ elseif reltol != = nothing
205
+ _reltol = reltol
206
+ else
207
+ throw (OverrideInitNoTolerance (:reltol ))
208
+ end
209
+ nlsol = solve (initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
183
210
184
211
u0 = initdata. initializeprobmap (nlsol)
185
212
if initdata. initializeprobpmap != = nothing
0 commit comments