Commit e22bfd4

Merge pull request #50 from AayushSabharwal/as/fix-test
test: fix initialization, account for `Initial` parameters in test
2 parents 3fe2727 + 2f855d3 commit e22bfd4

File tree: 2 files changed, +15 −13 lines

docs/src/friction.md

Lines changed: 6 additions & 5 deletions
@@ -110,8 +110,8 @@ sys = structural_simplify(ude_sys)
 We now setup the loss function and the optimization loop.
 
 ```@example friction
-function loss(x, (prob, sol_ref, get_vars, get_refs))
-    new_p = SciMLStructures.replace(Tunable(), prob.p, x)
+function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
+    new_p = set_x(prob, x)
     new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
     ts = sol_ref.t
     new_sol = solve(new_prob, Rodas4(), saveat = ts, abstol = 1e-8, reltol = 1e-8)
@@ -131,14 +131,15 @@ of = OptimizationFunction{true}(loss, AutoForwardDiff())
 prob = ODEProblem(sys, [], (0, 0.1), [])
 get_vars = getu(sys, [sys.friction.y])
 get_refs = getu(model_true, [model_true.y])
-x0 = reduce(vcat, getindex.((default_values(sys),), tunable_parameters(sys)))
+set_x = setp_oop(sys, sys.nn.p)
+x0 = default_values(sys)[sys.nn.p]
 
 cb = (opt_state, loss) -> begin
     @info "step $(opt_state.iter), loss: $loss"
     return false
 end
 
-op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs))
+op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs, set_x))
 res = solve(op, Adam(5e-3); maxiters = 10000, callback = cb)
 ```
 
@@ -147,7 +148,7 @@ res = solve(op, Adam(5e-3); maxiters = 10000, callback = cb)
 We now have a trained neural network! We can check whether running the simulation of the model embedded with the neural network matches the data or not.
 
 ```@example friction
-res_p = SciMLStructures.replace(Tunable(), prob.p, res.u)
+res_p = set_x(prob, res.u)
 res_prob = remake(prob, p = res_p)
 res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
 @test first.(sol_ref.u)≈first.(res_sol.u) rtol=1e-3 #hide
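Both files swap `SciMLStructures.replace(Tunable(), prob.p, x)` for an out-of-place setter built with `setp_oop`, which updates only the selected parameters (here the network weights) rather than the entire tunable portion — relevant now that ModelingToolkit also marks the automatically created `Initial` parameters as tunable. Below is a minimal sketch of the pattern using the generic `SymbolCache`/`ProblemState` providers from SymbolicIndexingInterface; the symbols `:a` and `:b` are hypothetical stand-ins for the tutorial's `sys.nn.p`:

```julia
using SymbolicIndexingInterface

# Toy index provider: one state `x`, two parameters `a` and `b`.
sys = SymbolCache([:x], [:a, :b], :t)
prob = ProblemState(; u = [1.0], p = [10.0, 20.0], t = 0.0)

# Out-of-place setter: returns a fresh parameter object and leaves `prob`
# untouched, so it composes with `remake(prob, p = new_p)` and with
# Dual-valued inputs under ForwardDiff.
set_a = setp_oop(sys, :a)
new_p = set_a(prob, 100.0)
```

The returned object carries the updated value for `a` while `prob` keeps its original parameters; in the tutorial the analogous `set_x(prob, x)` result is passed straight to `remake`.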

test/lotka_volterra.jl

Lines changed: 9 additions & 8 deletions
@@ -8,7 +8,7 @@ using SymbolicIndexingInterface
 using Optimization
 using OptimizationOptimisers: Adam
 using SciMLStructures
-using SciMLStructures: Tunable
+using SciMLStructures: Tunable, canonicalize
 using ForwardDiff
 using StableRNGs
 
@@ -51,7 +51,7 @@ eqs = [connect(model.nn_in, nn.output)
 
 ude_sys = complete(ODESystem(
     eqs, ModelingToolkit.t_nounits, systems = [model, nn],
-    name = :ude_sys, defaults = [nn.input.u => [0.0, 0.0]]))
+    name = :ude_sys))
 
 sys = structural_simplify(ude_sys)
 
@@ -61,13 +61,14 @@ model_true = structural_simplify(lotka_true())
 prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0), [])
 sol_ref = solve(prob_true, Rodas4())
 
-x0 = reduce(vcat, getindex.((default_values(sys),), tunable_parameters(sys)))
+x0 = default_values(sys)[nn.p]
 
 get_vars = getu(sys, [sys.lotka.x, sys.lotka.y])
 get_refs = getu(model_true, [model_true.x, model_true.y])
+set_x = setp_oop(sys, nn.p)
 
-function loss(x, (prob, sol_ref, get_vars, get_refs))
-    new_p = SciMLStructures.replace(Tunable(), prob.p, x)
+function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
+    new_p = set_x(prob, x)
     new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
     ts = sol_ref.t
     new_sol = solve(new_prob, Rodas4(), saveat = ts)
@@ -87,14 +88,14 @@ end
 
 of = OptimizationFunction{true}(loss, AutoForwardDiff())
 
-ps = (prob, sol_ref, get_vars, get_refs);
+ps = (prob, sol_ref, get_vars, get_refs, set_x);
 
 @test_call target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
 @test_opt target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
 
 @test all(.!isnan.(ForwardDiff.gradient(Base.Fix2(of, ps), x0)))
 
-op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs))
+op = OptimizationProblem(of, x0, ps)
 
 # using Plots
 
@@ -114,7 +115,7 @@ res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
 
 @test res.objective < 1
 
-res_p = SciMLStructures.replace(Tunable(), prob.p, res.u)
+res_p = set_x(prob, res.u)
 res_prob = remake(prob, p = res_p)
 res_sol = solve(res_prob, Rodas4(), saveat = sol_ref.t)
 
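The other recurring change is how the initial guess `x0` is assembled: instead of concatenating every tunable parameter — which now also sweeps in the `Initial` parameters ModelingToolkit creates for initialization — the code indexes `default_values(sys)` with the network's parameter object directly. A rough, self-contained illustration, with a plain `Dict` standing in for `default_values(sys)` and all names hypothetical:

```julia
# Stand-in for `default_values(sys)`: a map from symbolic variables to
# their defaults, here with the network weights plus two `Initial` entries.
defaults = Dict(:nn_p => [0.1, 0.2, 0.3], :Initial_x => 1.0, :Initial_y => 2.0)

# Old approach (sketch): concatenating *all* tunables also picks up the
# `Initial` parameters, so the optimization vector gains extra entries.
all_tunables = [:nn_p, :Initial_x, :Initial_y]
x0_old = reduce(vcat, getindex.((defaults,), all_tunables))

# New approach: index only the network parameters.
x0 = defaults[:nn_p]
```

Only the `:nn_p` entry belongs in the optimization vector; the `Initial` values are handled by the initialization system, not the optimizer.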