@@ -11,6 +11,9 @@ using SciMLStructures
using SciMLStructures: Tunable, canonicalize
using ForwardDiff
using StableRNGs
+ using DifferentiationInterface
+ using SciMLSensitivity
+ using Zygote: Zygote

function lotka_ude()
    @variables t x(t)=3.1 y(t)=1.5
@@ -59,7 +62,7 @@ prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0), [])
model_true = structural_simplify(lotka_true())
prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0), [])
- sol_ref = solve(prob_true, Rodas4())
+ sol_ref = solve(prob_true, Rodas5P(), abstol = 1e-10, reltol = 1e-8)

x0 = default_values(sys)[nn.p]
@@ -71,7 +74,7 @@ function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
    new_p = set_x(prob, x)
    new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
    ts = sol_ref.t
-     new_sol = solve(new_prob, Rodas4(), saveat = ts)
+     new_sol = solve(new_prob, Rodas5P(), abstol = 1e-10, reltol = 1e-8, saveat = ts)

    loss = zero(eltype(x))
@@ -86,14 +89,22 @@ function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
    end
end

- of = OptimizationFunction{true}(loss, AutoForwardDiff())
+ of = OptimizationFunction{true}(loss, AutoZygote())

ps = (prob, sol_ref, get_vars, get_refs, set_x);

@test_call target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
@test_opt target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)

- @test all(.!isnan.(ForwardDiff.gradient(Base.Fix2(of, ps), x0)))
+ ∇l1 = DifferentiationInterface.gradient(Base.Fix2(of, ps), AutoForwardDiff(), x0)
+ ∇l2 = DifferentiationInterface.gradient(Base.Fix2(of, ps), AutoFiniteDiff(), x0)
+ ∇l3 = DifferentiationInterface.gradient(Base.Fix2(of, ps), AutoZygote(), x0)
+
+ @test all(.!isnan.(∇l1))
+ @test !iszero(∇l1)
+
+ @test ∇l1 ≈ ∇l2 rtol = 1e-2
+ @test ∇l1 ≈ ∇l3 rtol = 1e-5

op = OptimizationProblem(of, x0, ps)
@@ -111,7 +122,7 @@ op = OptimizationProblem(of, x0, ps)
#     false
# end

- res = solve(op, Adam(), maxiters = 5000) #, callback = plot_cb)
+ res = solve(op, Adam(), maxiters = 10000) #, callback = plot_cb)

@test res.objective < 1