@@ -3,6 +3,8 @@ using Zygote
3
3
using ModelingToolkit: t_nounits as t, D_nounits as D
4
4
import SymbolicIndexingInterface as SII
5
5
import SciMLStructures as SS
6
+ using ModelingToolkitStandardLibrary
7
+ import ModelingToolkitStandardLibrary as MSL
6
8
7
9
@parameters σ ρ β
8
10
@variables x (t) y (t) z (t) w (t)
@@ -34,60 +36,18 @@ sol = solve(prob, Tsit5())
34
36
du_ = [0.0 , 1.0 , 1.0 , 1.0 ]
35
37
du = [du_ for _ in sol. u]
36
38
@test du == gs. u
37
- end
38
-
39
- # Lorenz
40
-
41
- @parameters σ ρ β
42
- @variables x (t) y (t) z (t)
43
-
44
- eqs = [D (x) ~ σ * (y - x),
45
- D (y) ~ x * (ρ - z) - y,
46
- D (z) ~ x * y - β * z]
47
-
48
- @named lorenz1 = ODESystem (eqs, t)
49
- @named lorenz2 = ODESystem (eqs, t)
50
-
51
- @parameters γ
52
- @variables a (t) α (t)
53
- connections = [0 ~ lorenz1. x + lorenz2. y + a * γ,
54
- α ~ 2 lorenz1. x + a * γ]
55
- @mtkbuild sys = ODESystem (connections, t, [a, α], [γ], systems = [lorenz1, lorenz2])
56
-
57
- u0 = [lorenz1. x => 1.0 ,
58
- lorenz1. y => 0.0 ,
59
- lorenz1. z => 0.0 ,
60
- lorenz2. x => 0.0 ,
61
- lorenz2. y => 1.0 ,
62
- lorenz2. z => 0.0 ,
63
- a => 2.0 ]
64
-
65
- p = [lorenz1. σ => 10.0 ,
66
- lorenz1. ρ => 28.0 ,
67
- lorenz1. β => 8 / 3 ,
68
- lorenz2. σ => 10.0 ,
69
- lorenz2. ρ => 28.0 ,
70
- lorenz2. β => 8 / 3 ,
71
- γ => 2.0 ]
72
-
73
- tspan = (0.0 , 100.0 )
74
- prob = ODEProblem (sys, u0, tspan, p)
75
- integ = init (prob, Rodas4 ())
76
- sol = solve (prob, Rodas4 ())
77
-
78
- gt = reduce (hcat, sol[[sys. a, sys. α]]) .+ randn .()
79
39
80
- gs, = Zygote. gradient (sol) do sol
81
- mean (abs .(sol[[sys. a, sys. α]] .- gt), dims = 2 )
40
+ # Observable in a vector
41
+ gs, = gradient (sol) do sol
42
+ sum (sum .(sol[[sys. w, sys. x]]))
43
+ end
44
+ du_ = [0.0 , 1.0 , 1.0 , 2.0 ]
45
+ du = [du_ for _ in sol. u]
46
+ @test du == gs. u
82
47
end
83
48
84
49
# DAE
85
50
86
- using ModelingToolkit, OrdinaryDiffEq, Zygote
87
- using ModelingToolkitStandardLibrary
88
- import ModelingToolkitStandardLibrary as MSL
89
- using SciMLStructures
90
-
91
51
function create_model (; C₁ = 3e-5 , C₂ = 1e-6 )
92
52
@variables t
93
53
@named resistor1 = MSL. Electrical. Resistor (R = 5.0 )
@@ -112,15 +72,20 @@ function create_model(; C₁ = 3e-5, C₂ = 1e-6)
112
72
])
113
73
end
114
74
115
- model = create_model ()
116
- sys = structural_simplify (model)
75
+ @testset " DAE Observable function AD" begin
76
+ model = create_model ()
77
+ sys = structural_simplify (model)
78
+
79
+ prob = ODEProblem (sys, [], (0.0 , 1.0 ))
80
+ sol = solve (prob, Rodas4 ())
117
81
118
- prob = ODEProblem (sys, [], (0.0 , 1.0 ))
119
- sol = solve (prob, Rodas4 ())
120
- pf = getp (sol, sys. resistor1. R)
121
- mtkparams = SII. parameter_values (sol)
122
- tunables, _, _ = SS. canonicalize (SS. Tunable (), mtkparams)
123
- p_new = rand (length (tunables))
82
+ gs, = gradient (sol) do sol
83
+ sum (sol[sys. ampermeter. i])
84
+ end
85
+ du_ = [0.2 , 1.0 ]
86
+ du = [du_ for _ in sol. u]
87
+ @test gs. u == du
88
+ end
124
89
125
90
# @testset "Adjoints with DAE" begin
126
91
# gs_mtkp, gs_p_new = gradient(mtkparams, p_new) do p, new_tunables
0 commit comments