Skip to content

Commit 44bfc91

Browse files
build: add MSL to test deps
1 parent 032b927 commit 44bfc91

File tree

3 files changed

+26
-61
lines changed

3 files changed

+26
-61
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ Logging = "1.10"
6868
Makie = "0.20"
6969
Markdown = "1.10"
7070
ModelingToolkit = "8.75, 9"
71+
ModelingToolkitStandardLibrary = "2.7"
7172
PartialFunctions = "1.1"
7273
PrecompileTools = "1.2"
7374
Preferences = "1.3"
@@ -96,6 +97,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
9697
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
9798
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
9899
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
100+
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
99101
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
100102
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
101103
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -111,4 +113,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
111113
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
112114

113115
[targets]
114-
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "ForwardDiff"]
116+
test = ["Pkg", "Plots", "UnicodePlots", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "ModelingToolkitStandardLibrary", "OrdinaryDiffEq", "ForwardDiff"]

ext/SciMLBaseZygoteExt.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,7 @@ end
220220
function solu_adjoint(Δ)
221221
zerou = zero(sol.prob.u0)
222222
= @. ifelse=== nothing, (zerou,), Δ)
223-
nt = Zygote.nt_nothing(sol)
224-
gs = Zygote.accum(nt, (u = _Δ,))
225-
(gs,)
223+
(build_solution(sol.prob, sol.alg, sol.t, _Δ),)
226224
end
227225
sol.u, solu_adjoint
228226
end

test/downstream/observables_autodiff.jl

Lines changed: 22 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ using Zygote
33
using ModelingToolkit: t_nounits as t, D_nounits as D
44
import SymbolicIndexingInterface as SII
55
import SciMLStructures as SS
6+
using ModelingToolkitStandardLibrary
7+
import ModelingToolkitStandardLibrary as MSL
68

79
@parameters σ ρ β
810
@variables x(t) y(t) z(t) w(t)
@@ -34,60 +36,18 @@ sol = solve(prob, Tsit5())
3436
du_ = [0.0, 1.0, 1.0, 1.0]
3537
du = [du_ for _ in sol.u]
3638
@test du == gs.u
37-
end
38-
39-
# Lorenz
40-
41-
@parameters σ ρ β
42-
@variables x(t) y(t) z(t)
43-
44-
eqs = [D(x) ~ σ * (y - x),
45-
D(y) ~ x *- z) - y,
46-
D(z) ~ x * y - β * z]
47-
48-
@named lorenz1 = ODESystem(eqs, t)
49-
@named lorenz2 = ODESystem(eqs, t)
50-
51-
@parameters γ
52-
@variables a(t) α(t)
53-
connections = [0 ~ lorenz1.x + lorenz2.y + a * γ,
54-
α ~ 2lorenz1.x + a * γ]
55-
@mtkbuild sys = ODESystem(connections, t, [a, α], [γ], systems = [lorenz1, lorenz2])
56-
57-
u0 = [lorenz1.x => 1.0,
58-
lorenz1.y => 0.0,
59-
lorenz1.z => 0.0,
60-
lorenz2.x => 0.0,
61-
lorenz2.y => 1.0,
62-
lorenz2.z => 0.0,
63-
a => 2.0]
64-
65-
p = [lorenz1.σ => 10.0,
66-
lorenz1.ρ => 28.0,
67-
lorenz1.β => 8 / 3,
68-
lorenz2.σ => 10.0,
69-
lorenz2.ρ => 28.0,
70-
lorenz2.β => 8 / 3,
71-
γ => 2.0]
72-
73-
tspan = (0.0, 100.0)
74-
prob = ODEProblem(sys, u0, tspan, p)
75-
integ = init(prob, Rodas4())
76-
sol = solve(prob, Rodas4())
77-
78-
gt = reduce(hcat, sol[[sys.a, sys.α]]) .+ randn.()
7939

80-
gs, = Zygote.gradient(sol) do sol
81-
mean(abs.(sol[[sys.a, sys.α]] .- gt), dims = 2)
40+
# Observable in a vector
41+
gs, = gradient(sol) do sol
42+
sum(sum.(sol[[sys.w, sys.x]]))
43+
end
44+
du_ = [0.0, 1.0, 1.0, 2.0]
45+
du = [du_ for _ in sol.u]
46+
@test du == gs.u
8247
end
8348

8449
# DAE
8550

86-
using ModelingToolkit, OrdinaryDiffEq, Zygote
87-
using ModelingToolkitStandardLibrary
88-
import ModelingToolkitStandardLibrary as MSL
89-
using SciMLStructures
90-
9151
function create_model(; C₁ = 3e-5, C₂ = 1e-6)
9252
@variables t
9353
@named resistor1 = MSL.Electrical.Resistor(R = 5.0)
@@ -112,15 +72,20 @@ function create_model(; C₁ = 3e-5, C₂ = 1e-6)
11272
])
11373
end
11474

115-
model = create_model()
116-
sys = structural_simplify(model)
75+
@testset "DAE Observable function AD" begin
76+
model = create_model()
77+
sys = structural_simplify(model)
78+
79+
prob = ODEProblem(sys, [], (0.0, 1.0))
80+
sol = solve(prob, Rodas4())
11781

118-
prob = ODEProblem(sys, [], (0.0, 1.0))
119-
sol = solve(prob, Rodas4())
120-
pf = getp(sol, sys.resistor1.R)
121-
mtkparams = SII.parameter_values(sol)
122-
tunables, _, _ = SS.canonicalize(SS.Tunable(), mtkparams)
123-
p_new = rand(length(tunables))
82+
gs, = gradient(sol) do sol
83+
sum(sol[sys.ampermeter.i])
84+
end
85+
du_ = [0.2, 1.0]
86+
du = [du_ for _ in sol.u]
87+
@test gs.u == du
88+
end
12489

12590
# @testset "Adjoints with DAE" begin
12691
# gs_mtkp, gs_p_new = gradient(mtkparams, p_new) do p, new_tunables

0 commit comments

Comments
 (0)