@@ -25,6 +25,7 @@ The `simplified_sys` has undergone [`structural_simplify`](@ref) and had any occ
25
25
- `simplify`: Apply simplification in tearing.
26
26
- `initialize`: If true, a check is performed to ensure that the operating point is consistent (satisfies algebraic equations). If the op is not consistent, initialization is performed.
27
27
- `initialization_solver_alg`: A NonlinearSolve algorithm to use for solving for a feasible set of state and algebraic variables that satisfies the specified operating point.
28
+ - `autodiff`: An `ADType` supported by DifferentiationInterface.jl to use for calculating the necessary jacobians. Defaults to using `AutoForwardDiff()`
28
29
- `kwargs`: Are passed on to `find_solvables!`
29
30
30
31
See also [`linearize`](@ref) which provides a higher-level interface.
@@ -39,6 +40,7 @@ function linearization_function(sys::AbstractSystem, inputs,
39
40
p = DiffEqBase. NullParameters (),
40
41
zero_dummy_der = false ,
41
42
initialization_solver_alg = TrustRegion (),
43
+ autodiff = AutoForwardDiff (),
42
44
eval_expression = false , eval_module = @__MODULE__ ,
43
45
warn_initialize_determined = true ,
44
46
guesses = Dict (),
@@ -82,13 +84,104 @@ function linearization_function(sys::AbstractSystem, inputs,
82
84
initialization_kwargs = (;
83
85
abstol = initialization_abstol, reltol = initialization_reltol,
84
86
nlsolve_alg = initialization_solver_alg)
87
+
88
+ p = parameter_values (prob)
89
+ t0 = current_time (prob)
90
+ inputvals = [p[idx] for idx in input_idxs]
91
+
92
+ hp_fun = let fun = h, setter = setp_oop (sys, input_idxs)
93
+ function hpf (du, input, u, p, t)
94
+ p = setter (p, input)
95
+ fun (du, u, p, t)
96
+ return du
97
+ end
98
+ end
99
+ if u0 === nothing
100
+ uf_jac = h_jac = pf_jac = nothing
101
+ T = p isa MTKParameters ? eltype (p. tunable) : eltype (p)
102
+ hp_jac = PreparedJacobian {true} (
103
+ hp_fun, zeros (T, size (outputs)), autodiff, inputvals,
104
+ DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
105
+ else
106
+ uf_fun = let fun = prob. f
107
+ function uff (du, u, p, t)
108
+ SciMLBase. UJacobianWrapper (fun, t, p)(du, u)
109
+ end
110
+ end
111
+ uf_jac = PreparedJacobian {true} (
112
+ uf_fun, similar (prob. u0), autodiff, prob. u0, DI. Constant (p), DI. Constant (t0))
113
+ # observed function is a `GeneratedFunctionWrapper` with iip component
114
+ h_jac = PreparedJacobian {true} (h, similar (prob. u0, size (outputs)), autodiff,
115
+ prob. u0, DI. Constant (p), DI. Constant (t0))
116
+ pf_fun = let fun = prob. f, setter = setp_oop (sys, input_idxs)
117
+ function pff (du, input, u, p, t)
118
+ p = setter (p, input)
119
+ SciMLBase. ParamJacobianWrapper (fun, t, u)(du, p)
120
+ end
121
+ end
122
+ pf_jac = PreparedJacobian {true} (pf_fun, similar (prob. u0), autodiff, inputvals,
123
+ DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
124
+ hp_jac = PreparedJacobian {true} (
125
+ hp_fun, similar (prob. u0, size (outputs)), autodiff, inputvals,
126
+ DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
127
+ end
128
+
85
129
lin_fun = LinearizationFunction (
86
130
diff_idxs, alge_idxs, input_idxs, length (unknowns (sys)),
87
- prob, h, u0 === nothing ? nothing : similar (u0),
88
- ForwardDiff . Chunk (input_idxs) , initializealg, initialization_kwargs)
131
+ prob, h, u0 === nothing ? nothing : similar (u0), uf_jac, h_jac, pf_jac,
132
+ hp_jac , initializealg, initialization_kwargs)
89
133
return lin_fun, sys
90
134
end
91
135
136
+ """
137
+ $(TYPEDEF)
138
+
139
+ Callable struct which stores a function and its prepared `DI.jacobian`. Calling with the
140
+ appropriate arguments for DI returns the jacobian.
141
+
142
+ # Fields
143
+
144
+ $(TYPEDFIELDS)
145
+ """
146
+ struct PreparedJacobian{iip, P, F, B, A}
147
+ """
148
+ The preparation object.
149
+ """
150
+ prep:: P
151
+ """
152
+ The function whose jacobian is calculated.
153
+ """
154
+ f:: F
155
+ """
156
+ Buffer for in-place functions.
157
+ """
158
+ buf:: B
159
+ """
160
+ ADType to use for differentiation.
161
+ """
162
+ autodiff:: A
163
+ end
164
+
165
+ function PreparedJacobian {true} (f, buf, autodiff, args... )
166
+ prep = DI. prepare_jacobian (f, buf, autodiff, args... )
167
+ return PreparedJacobian {true, typeof(prep), typeof(f), typeof(buf), typeof(autodiff)} (
168
+ prep, f, buf, autodiff)
169
+ end
170
+
171
+ function PreparedJacobian {false} (f, autodiff, args... )
172
+ prep = DI. prepare_jacobian (f, autodiff, args... )
173
+ return PreparedJacobian {true, typeof(prep), typeof(f), Nothing, typeof(autodiff)} (
174
+ prep, f, nothing )
175
+ end
176
+
177
+ function (pj:: PreparedJacobian{true} )(args... )
178
+ DI. jacobian (pj. f, pj. buf, pj. prep, pj. autodiff, args... )
179
+ end
180
+
181
+ function (pj:: PreparedJacobian{false} )(args... )
182
+ DI. jacobian (pj. f, pj. prep, pj. autodiff, args... )
183
+ end
184
+
92
185
"""
93
186
$(TYPEDEF)
94
187
@@ -100,7 +193,7 @@ $(TYPEDFIELDS)
100
193
"""
101
194
struct LinearizationFunction{
102
195
DI <: AbstractVector{Int} , AI <: AbstractVector{Int} , II, P <: ODEProblem ,
103
- H, C, Ch , IA <: SciMLBase.DAEInitializationAlgorithm , IK}
196
+ H, C, J1, J2, J3, J4 , IA <: SciMLBase.DAEInitializationAlgorithm , IK}
104
197
"""
105
198
The indexes of differential equations in the linearized system.
106
199
"""
@@ -130,11 +223,22 @@ struct LinearizationFunction{
130
223
Any required cache buffers.
131
224
"""
132
225
caches:: C
133
- # TODO : Use DI?
134
226
"""
135
- A `ForwardDiff.Chunk` for taking the jacobian with respect to the inputs.
227
+ `PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `u`
228
+ """
229
+ uf_jac:: J1
230
+ """
231
+ `PreparedJacobian` for calculating jacobian of `h` w.r.t. `u`
136
232
"""
137
- chunk:: Ch
233
+ h_jac:: J2
234
+ """
235
+ `PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `p`
236
+ """
237
+ pf_jac:: J3
238
+ """
239
+ `PreparedJacobian` for calculating jacobian of `h` w.r.t. `p`
240
+ """
241
+ hp_jac:: J4
138
242
"""
139
243
The initialization algorithm to use.
140
244
"""
@@ -188,25 +292,18 @@ function (linfun::LinearizationFunction)(u, p, t)
188
292
if ! success
189
293
error (" Initialization algorithm $(linfun. initializealg) failed with `u = $u ` and `p = $p `." )
190
294
end
191
- uf = SciMLBase. UJacobianWrapper (fun, t, p)
192
- fg_xz = ForwardDiff. jacobian (uf, u)
193
- h_xz = ForwardDiff. jacobian (
194
- let p = p, t = t, h = linfun. h
195
- xz -> h (xz, p, t)
196
- end , u)
197
- pf = SciMLBase. ParamJacobianWrapper (fun, t, u)
198
- fg_u = jacobian_wrt_vars (pf, p, linfun. input_idxs, linfun. chunk)
295
+ fg_xz = linfun. uf_jac (u, DI. Constant (p), DI. Constant (t))
296
+ h_xz = linfun. h_jac (u, DI. Constant (p), DI. Constant (t))
297
+ fg_u = linfun. pf_jac ([p[idx] for idx in linfun. input_idxs],
298
+ DI. Constant (u), DI. Constant (p), DI. Constant (t))
199
299
else
200
300
linfun. num_states == 0 ||
201
301
error (" Number of unknown variables (0) does not match the number of input unknowns ($(length (u)) )" )
202
302
fg_xz = zeros (0 , 0 )
203
303
h_xz = fg_u = zeros (0 , length (linfun. input_idxs))
204
304
end
205
- hp = let u = u, t = t, h = linfun. h
206
- _hp (p) = h (u, p, t)
207
- _hp
208
- end
209
- h_u = jacobian_wrt_vars (hp, p, linfun. input_idxs, linfun. chunk)
305
+ h_u = linfun. hp_jac ([p[idx] for idx in linfun. input_idxs],
306
+ DI. Constant (u), DI. Constant (p), DI. Constant (t))
210
307
(f_x = fg_xz[linfun. diff_idxs, linfun. diff_idxs],
211
308
f_z = fg_xz[linfun. diff_idxs, linfun. alge_idxs],
212
309
g_x = fg_xz[linfun. alge_idxs, linfun. diff_idxs],
0 commit comments