Skip to content

Commit 96e53ed

Browse files
Merge pull request #3477 from AayushSabharwal/as/linearize-ad
feat: support alternative AD backends in linearization
2 parents 5cfb1b6 + d448fc5 commit 96e53ed

File tree

6 files changed

+136
-79
lines changed

6 files changed

+136
-79
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Yingbo Ma <mayingbo5@gmail.com>", "Chris Rackauckas <accounts@chrisr
44
version = "9.68.1"
55

66
[deps]
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
78
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
89
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
910
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
@@ -16,6 +17,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1617
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
1718
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
1819
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
20+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1921
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
2022
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
2123
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -77,6 +79,7 @@ MTKInfiniteOptExt = "InfiniteOpt"
7779
MTKLabelledArraysExt = "LabelledArrays"
7880

7981
[compat]
82+
ADTypes = "1.14.0"
8083
AbstractTrees = "0.3, 0.4"
8184
ArrayInterface = "6, 7"
8285
BifurcationKit = "0.4"
@@ -96,6 +99,7 @@ DiffEqBase = "6.165.1"
9699
DiffEqCallbacks = "2.16, 3, 4"
97100
DiffEqNoiseProcess = "5"
98101
DiffRules = "0.1, 1.0"
102+
DifferentiationInterface = "0.6.47"
99103
Distributed = "1"
100104
Distributions = "0.23, 0.24, 0.25"
101105
DocStringExtensions = "0.7, 0.8, 0.9"
@@ -156,8 +160,8 @@ julia = "1.9"
156160
[extras]
157161
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
158162
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
159-
BoundaryValueDiffEqMIRK = "1a22d4ce-7765-49ea-b6f2-13c8438986a6"
160163
BoundaryValueDiffEqAscher = "7227322d-7511-4e07-9247-ad6ff830280e"
164+
BoundaryValueDiffEqMIRK = "1a22d4ce-7765-49ea-b6f2-13c8438986a6"
161165
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
162166
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
163167
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"

src/ModelingToolkit.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ RuntimeGeneratedFunctions.init(@__MODULE__)
9494
import DynamicQuantities, Unitful
9595
const DQ = DynamicQuantities
9696

97+
import DifferentiationInterface as DI
98+
using ADTypes: AutoForwardDiff
99+
97100
export @derivatives
98101

99102
for fun in [:toexpr]

src/linearization.jl

Lines changed: 116 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ The `simplified_sys` has undergone [`structural_simplify`](@ref) and had any occ
2525
- `simplify`: Apply simplification in tearing.
2626
- `initialize`: If true, a check is performed to ensure that the operating point is consistent (satisfies algebraic equations). If the op is not consistent, initialization is performed.
2727
- `initialization_solver_alg`: A NonlinearSolve algorithm to use for solving for a feasible set of state and algebraic variables that satisfies the specified operating point.
28+
- `autodiff`: An `ADType` supported by DifferentiationInterface.jl to use for calculating the necessary jacobians. Defaults to using `AutoForwardDiff()`
2829
- `kwargs`: Are passed on to `find_solvables!`
2930
3031
See also [`linearize`](@ref) which provides a higher-level interface.
@@ -39,6 +40,7 @@ function linearization_function(sys::AbstractSystem, inputs,
3940
p = DiffEqBase.NullParameters(),
4041
zero_dummy_der = false,
4142
initialization_solver_alg = TrustRegion(),
43+
autodiff = AutoForwardDiff(),
4244
eval_expression = false, eval_module = @__MODULE__,
4345
warn_initialize_determined = true,
4446
guesses = Dict(),
@@ -82,13 +84,104 @@ function linearization_function(sys::AbstractSystem, inputs,
8284
initialization_kwargs = (;
8385
abstol = initialization_abstol, reltol = initialization_reltol,
8486
nlsolve_alg = initialization_solver_alg)
87+
88+
p = parameter_values(prob)
89+
t0 = current_time(prob)
90+
inputvals = [p[idx] for idx in input_idxs]
91+
92+
hp_fun = let fun = h, setter = setp_oop(sys, input_idxs)
93+
function hpf(du, input, u, p, t)
94+
p = setter(p, input)
95+
fun(du, u, p, t)
96+
return du
97+
end
98+
end
99+
if u0 === nothing
100+
uf_jac = h_jac = pf_jac = nothing
101+
T = p isa MTKParameters ? eltype(p.tunable) : eltype(p)
102+
hp_jac = PreparedJacobian{true}(
103+
hp_fun, zeros(T, size(outputs)), autodiff, inputvals,
104+
DI.Constant(prob.u0), DI.Constant(p), DI.Constant(t0))
105+
else
106+
uf_fun = let fun = prob.f
107+
function uff(du, u, p, t)
108+
SciMLBase.UJacobianWrapper(fun, t, p)(du, u)
109+
end
110+
end
111+
uf_jac = PreparedJacobian{true}(
112+
uf_fun, similar(prob.u0), autodiff, prob.u0, DI.Constant(p), DI.Constant(t0))
113+
# observed function is a `GeneratedFunctionWrapper` with iip component
114+
h_jac = PreparedJacobian{true}(h, similar(prob.u0, size(outputs)), autodiff,
115+
prob.u0, DI.Constant(p), DI.Constant(t0))
116+
pf_fun = let fun = prob.f, setter = setp_oop(sys, input_idxs)
117+
function pff(du, input, u, p, t)
118+
p = setter(p, input)
119+
SciMLBase.ParamJacobianWrapper(fun, t, u)(du, p)
120+
end
121+
end
122+
pf_jac = PreparedJacobian{true}(pf_fun, similar(prob.u0), autodiff, inputvals,
123+
DI.Constant(prob.u0), DI.Constant(p), DI.Constant(t0))
124+
hp_jac = PreparedJacobian{true}(
125+
hp_fun, similar(prob.u0, size(outputs)), autodiff, inputvals,
126+
DI.Constant(prob.u0), DI.Constant(p), DI.Constant(t0))
127+
end
128+
85129
lin_fun = LinearizationFunction(
86130
diff_idxs, alge_idxs, input_idxs, length(unknowns(sys)),
87-
prob, h, u0 === nothing ? nothing : similar(u0),
88-
ForwardDiff.Chunk(input_idxs), initializealg, initialization_kwargs)
131+
prob, h, u0 === nothing ? nothing : similar(u0), uf_jac, h_jac, pf_jac,
132+
hp_jac, initializealg, initialization_kwargs)
89133
return lin_fun, sys
90134
end
91135

136+
"""
137+
$(TYPEDEF)
138+
139+
Callable struct which stores a function and its prepared `DI.jacobian`. Calling with the
140+
appropriate arguments for DI returns the jacobian.
141+
142+
# Fields
143+
144+
$(TYPEDFIELDS)
145+
"""
146+
struct PreparedJacobian{iip, P, F, B, A}
147+
"""
148+
The preparation object.
149+
"""
150+
prep::P
151+
"""
152+
The function whose jacobian is calculated.
153+
"""
154+
f::F
155+
"""
156+
Buffer for in-place functions.
157+
"""
158+
buf::B
159+
"""
160+
ADType to use for differentiation.
161+
"""
162+
autodiff::A
163+
end
164+
165+
function PreparedJacobian{true}(f, buf, autodiff, args...)
166+
prep = DI.prepare_jacobian(f, buf, autodiff, args...)
167+
return PreparedJacobian{true, typeof(prep), typeof(f), typeof(buf), typeof(autodiff)}(
168+
prep, f, buf, autodiff)
169+
end
170+
171+
function PreparedJacobian{false}(f, autodiff, args...)
172+
prep = DI.prepare_jacobian(f, autodiff, args...)
173+
return PreparedJacobian{true, typeof(prep), typeof(f), Nothing, typeof(autodiff)}(
174+
prep, f, nothing)
175+
end
176+
177+
function (pj::PreparedJacobian{true})(args...)
178+
DI.jacobian(pj.f, pj.buf, pj.prep, pj.autodiff, args...)
179+
end
180+
181+
function (pj::PreparedJacobian{false})(args...)
182+
DI.jacobian(pj.f, pj.prep, pj.autodiff, args...)
183+
end
184+
92185
"""
93186
$(TYPEDEF)
94187
@@ -100,7 +193,7 @@ $(TYPEDFIELDS)
100193
"""
101194
struct LinearizationFunction{
102195
DI <: AbstractVector{Int}, AI <: AbstractVector{Int}, II, P <: ODEProblem,
103-
H, C, Ch, IA <: SciMLBase.DAEInitializationAlgorithm, IK}
196+
H, C, J1, J2, J3, J4, IA <: SciMLBase.DAEInitializationAlgorithm, IK}
104197
"""
105198
The indexes of differential equations in the linearized system.
106199
"""
@@ -130,11 +223,22 @@ struct LinearizationFunction{
130223
Any required cache buffers.
131224
"""
132225
caches::C
133-
# TODO: Use DI?
134226
"""
135-
A `ForwardDiff.Chunk` for taking the jacobian with respect to the inputs.
227+
`PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `u`
228+
"""
229+
uf_jac::J1
230+
"""
231+
`PreparedJacobian` for calculating jacobian of `h` w.r.t. `u`
136232
"""
137-
chunk::Ch
233+
h_jac::J2
234+
"""
235+
`PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `p`
236+
"""
237+
pf_jac::J3
238+
"""
239+
`PreparedJacobian` for calculating jacobian of `h` w.r.t. `p`
240+
"""
241+
hp_jac::J4
138242
"""
139243
The initialization algorithm to use.
140244
"""
@@ -188,25 +292,18 @@ function (linfun::LinearizationFunction)(u, p, t)
188292
if !success
189293
error("Initialization algorithm $(linfun.initializealg) failed with `u = $u` and `p = $p`.")
190294
end
191-
uf = SciMLBase.UJacobianWrapper(fun, t, p)
192-
fg_xz = ForwardDiff.jacobian(uf, u)
193-
h_xz = ForwardDiff.jacobian(
194-
let p = p, t = t, h = linfun.h
195-
xz -> h(xz, p, t)
196-
end, u)
197-
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
198-
fg_u = jacobian_wrt_vars(pf, p, linfun.input_idxs, linfun.chunk)
295+
fg_xz = linfun.uf_jac(u, DI.Constant(p), DI.Constant(t))
296+
h_xz = linfun.h_jac(u, DI.Constant(p), DI.Constant(t))
297+
fg_u = linfun.pf_jac([p[idx] for idx in linfun.input_idxs],
298+
DI.Constant(u), DI.Constant(p), DI.Constant(t))
199299
else
200300
linfun.num_states == 0 ||
201301
error("Number of unknown variables (0) does not match the number of input unknowns ($(length(u)))")
202302
fg_xz = zeros(0, 0)
203303
h_xz = fg_u = zeros(0, length(linfun.input_idxs))
204304
end
205-
hp = let u = u, t = t, h = linfun.h
206-
_hp(p) = h(u, p, t)
207-
_hp
208-
end
209-
h_u = jacobian_wrt_vars(hp, p, linfun.input_idxs, linfun.chunk)
305+
h_u = linfun.hp_jac([p[idx] for idx in linfun.input_idxs],
306+
DI.Constant(u), DI.Constant(p), DI.Constant(t))
210307
(f_x = fg_xz[linfun.diff_idxs, linfun.diff_idxs],
211308
f_z = fg_xz[linfun.diff_idxs, linfun.alge_idxs],
212309
g_x = fg_xz[linfun.alge_idxs, linfun.diff_idxs],

src/systems/parameter_buffer.jl

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -734,35 +734,6 @@ function Base.:(==)(a::MTKParameters, b::MTKParameters)
734734
end)
735735
end
736736

737-
# to support linearize/linearization_function
738-
function jacobian_wrt_vars(pf::F, p::MTKParameters, input_idxs, chunk::C) where {F, C}
739-
tunable, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
740-
T = eltype(tunable)
741-
tag = ForwardDiff.Tag(pf, T)
742-
dualtype = ForwardDiff.Dual{typeof(tag), T, ForwardDiff.chunksize(chunk)}
743-
p_big = SciMLStructures.replace(SciMLStructures.Tunable(), p, dualtype.(tunable))
744-
p_closure = let pf = pf,
745-
input_idxs = input_idxs,
746-
p_big = p_big
747-
748-
function (p_small_inner)
749-
for (i, val) in zip(input_idxs, p_small_inner)
750-
set_parameter!(p_big, val, i)
751-
end
752-
return if pf isa SciMLBase.ParamJacobianWrapper
753-
buffer = Array{dualtype}(undef, size(pf.u))
754-
pf(buffer, p_big)
755-
buffer
756-
else
757-
pf(p_big)
758-
end
759-
end
760-
end
761-
p_small = parameter_values.((p,), input_idxs)
762-
cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk, tag)
763-
ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false))
764-
end
765-
766737
const MISSING_PARAMETERS_MESSAGE = """
767738
Some parameters are missing from the variable map.
768739
Please provide a value or default for the following variables:

src/utils.jl

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -963,27 +963,6 @@ function Base.iterate(it::StatefulBFS, queue = (eltype(it)[(0, it.t)]))
963963
return (lv, t), queue
964964
end
965965

966-
function jacobian_wrt_vars(pf::F, p, input_idxs, chunk::C) where {F, C}
967-
E = eltype(p)
968-
tag = ForwardDiff.Tag(pf, E)
969-
T = typeof(tag)
970-
dualtype = ForwardDiff.Dual{T, E, ForwardDiff.chunksize(chunk)}
971-
p_big = similar(p, dualtype)
972-
copyto!(p_big, p)
973-
p_closure = let pf = pf,
974-
input_idxs = input_idxs,
975-
p_big = p_big
976-
977-
function (p_small_inner)
978-
p_big[input_idxs] .= p_small_inner
979-
pf(p_big)
980-
end
981-
end
982-
p_small = p[input_idxs]
983-
cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk, tag)
984-
ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false))
985-
end
986-
987966
function fold_constants(ex)
988967
if iscall(ex)
989968
maketerm(typeof(ex), operation(ex), map(fold_constants, arguments(ex)),

test/downstream/linearize.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, Test
1+
using ModelingToolkit, ADTypes, Test
22
using CommonSolve: solve
33

44
# r is an input, and y is an output.
@@ -17,11 +17,12 @@ eqs = [u ~ kp * (r - y)
1717
lsys, ssys = linearize(sys, [r], [y])
1818
lprob = LinearizationProblem(sys, [r], [y])
1919
lsys2 = solve(lprob)
20+
lsys3, _ = linearize(sys, [r], [y]; autodiff = AutoFiniteDiff())
2021

21-
@test lsys.A[] == lsys2.A[] == -2
22-
@test lsys.B[] == lsys2.B[] == 1
23-
@test lsys.C[] == lsys2.C[] == 1
24-
@test lsys.D[] == lsys2.D[] == 0
22+
@test lsys.A[] == lsys2.A[] == lsys3.A[] == -2
23+
@test lsys.B[] == lsys2.B[] == lsys3.B[] == 1
24+
@test lsys.C[] == lsys2.C[] == lsys3.C[] == 1
25+
@test lsys.D[] == lsys2.D[] == lsys3.D[] == 0
2526

2627
lsys, ssys = linearize(sys, [r], [r])
2728

@@ -89,11 +90,13 @@ connections = [f.y ~ c.r # filtered reference to controller reference
8990
lsys0, ssys = linearize(cl, [f.u], [p.x])
9091
desired_order = [f.x, p.x]
9192
lsys = ModelingToolkit.reorder_unknowns(lsys0, unknowns(ssys), desired_order)
93+
lsys1, ssys = linearize(cl, [f.u], [p.x]; autodiff = AutoFiniteDiff())
94+
lsys2 = ModelingToolkit.reorder_unknowns(lsys1, unknowns(ssys), desired_order)
9295

93-
@test lsys.A == [-2 0; 1 -2]
94-
@test lsys.B == reshape([1, 0], 2, 1)
95-
@test lsys.C == [0 1]
96-
@test lsys.D[] == 0
96+
@test lsys.A == lsys2.A == [-2 0; 1 -2]
97+
@test lsys.B == lsys2.B == reshape([1, 0], 2, 1)
98+
@test lsys.C == lsys2.C == [0 1]
99+
@test lsys.D[] == lsys2.D[] == 0
97100

98101
## Symbolic linearization
99102
lsyss, _ = ModelingToolkit.linearize_symbolic(cl, [f.u], [p.x])

0 commit comments

Comments
 (0)