Skip to content

Commit 5564ad6

Browse files
Merge pull request #45 from SciML/smc/ad
Add reverse mode AD tests
2 parents c5ae0ce + 40d6e53 commit 5564ad6

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

Project.toml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,36 +15,42 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1515
[compat]
1616
Aqua = "0.8"
1717
ComponentArrays = "0.15.11"
18+
DifferentiationInterface = "0.6"
1819
ForwardDiff = "0.10.36"
1920
JET = "0.8, 0.9"
2021
Lux = "1"
2122
LuxCore = "1"
22-
ModelingToolkit = "9.9"
23+
ModelingToolkit = "9.64"
2324
ModelingToolkitStandardLibrary = "2.7"
24-
Optimization = "3.24"
25-
OptimizationOptimisers = "0.2.1"
25+
Optimization = "3.24, 4"
26+
OptimizationOptimisers = "0.2.1, 0.3"
2627
OrdinaryDiffEq = "6.74"
2728
Random = "1.10"
2829
SafeTestsets = "0.1"
30+
SciMLSensitivity = "7.72"
2931
SciMLStructures = "1.1.0"
3032
StableRNGs = "1"
3133
SymbolicIndexingInterface = "0.3.15"
32-
Symbolics = "6.9"
34+
Symbolics = "6.22"
3335
Test = "1.10"
36+
Zygote = "0.6.73"
3437
julia = "1.10"
3538

3639
[extras]
3740
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
41+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
3842
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3943
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
4044
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
4145
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
4246
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
4347
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
48+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
4449
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
4550
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
4651
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
4752
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
53+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4854

4955
[targets]
50-
test = ["Aqua", "JET", "Test", "OrdinaryDiffEq", "ForwardDiff", "Optimization", "OptimizationOptimisers", "SafeTestsets", "SciMLStructures", "StableRNGs", "SymbolicIndexingInterface"]
56+
test = ["Aqua", "JET", "Test", "OrdinaryDiffEq", "DifferentiationInterface", "SciMLSensitivity", "Zygote", "ForwardDiff", "Optimization", "OptimizationOptimisers", "SafeTestsets", "SciMLStructures", "StableRNGs", "SymbolicIndexingInterface"]

test/lotka_volterra.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ using SciMLStructures
1111
using SciMLStructures: Tunable, canonicalize
1212
using ForwardDiff
1313
using StableRNGs
14+
using DifferentiationInterface
15+
using SciMLSensitivity
16+
using Zygote: Zygote
1417

1518
function lotka_ude()
1619
@variables t x(t)=3.1 y(t)=1.5
@@ -59,7 +62,7 @@ prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0), [])
5962

6063
model_true = structural_simplify(lotka_true())
6164
prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0), [])
62-
sol_ref = solve(prob_true, Rodas4())
65+
sol_ref = solve(prob_true, Rodas5P(), abstol = 1e-10, reltol = 1e-8)
6366

6467
x0 = default_values(sys)[nn.p]
6568

@@ -71,7 +74,7 @@ function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
7174
new_p = set_x(prob, x)
7275
new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
7376
ts = sol_ref.t
74-
new_sol = solve(new_prob, Rodas4(), saveat = ts)
77+
new_sol = solve(new_prob, Rodas5P(), abstol = 1e-10, reltol = 1e-8, saveat = ts)
7578

7679
loss = zero(eltype(x))
7780

@@ -86,14 +89,22 @@ function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
8689
end
8790
end
8891

89-
of = OptimizationFunction{true}(loss, AutoForwardDiff())
92+
of = OptimizationFunction{true}(loss, AutoZygote())
9093

9194
ps = (prob, sol_ref, get_vars, get_refs, set_x);
9295

9396
@test_call target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
9497
@test_opt target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
9598

96-
@test all(.!isnan.(ForwardDiff.gradient(Base.Fix2(of, ps), x0)))
99+
∇l1 = DifferentiationInterface.gradient(Base.Fix2(of, ps), AutoForwardDiff(), x0)
100+
∇l2 = DifferentiationInterface.gradient(Base.Fix2(of, ps), AutoFiniteDiff(), x0)
101+
∇l3 = DifferentiationInterface.gradient(Base.Fix2(of, ps), AutoZygote(), x0)
102+
103+
@test all(.!isnan.(∇l1))
104+
@test !iszero(∇l1)
105+
106+
@test ∇l1≈∇l2 rtol=1e-2
107+
@test ∇l1≈∇l3 rtol=1e-5
97108

98109
op = OptimizationProblem(of, x0, ps)
99110

@@ -111,7 +122,7 @@ op = OptimizationProblem(of, x0, ps)
111122
# false
112123
# end
113124

114-
res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
125+
res = solve(op, Adam(), maxiters = 10000)#, callback = plot_cb)
115126

116127
@test res.objective < 1
117128

0 commit comments

Comments
 (0)