@@ -8,7 +8,7 @@ using SymbolicIndexingInterface
8
8
using Optimization
9
9
using OptimizationOptimisers: Adam
10
10
using SciMLStructures
11
- using SciMLStructures: Tunable
11
+ using SciMLStructures: Tunable, canonicalize
12
12
using ForwardDiff
13
13
using StableRNGs
14
14
@@ -51,7 +51,7 @@ eqs = [connect(model.nn_in, nn.output)
51
51
52
52
ude_sys = complete (ODESystem (
53
53
eqs, ModelingToolkit. t_nounits, systems = [model, nn],
54
- name = :ude_sys , defaults = [nn . input . u => [ 0.0 , 0.0 ]] ))
54
+ name = :ude_sys ))
55
55
56
56
sys = structural_simplify (ude_sys)
57
57
@@ -61,13 +61,14 @@ model_true = structural_simplify(lotka_true())
61
61
prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 1.0 ), [])
62
62
sol_ref = solve (prob_true, Rodas4 ())
63
63
64
- x0 = reduce (vcat, getindex .(( default_values (sys),), tunable_parameters (sys)))
64
+ x0 = default_values (sys)[nn . p]
65
65
66
66
get_vars = getu (sys, [sys. lotka. x, sys. lotka. y])
67
67
get_refs = getu (model_true, [model_true. x, model_true. y])
68
+ set_x = setp_oop (sys, nn. p)
68
69
69
- function loss (x, (prob, sol_ref, get_vars, get_refs))
70
- new_p = SciMLStructures . replace ( Tunable (), prob. p , x)
70
+ function loss (x, (prob, sol_ref, get_vars, get_refs, set_x ))
71
+ new_p = set_x ( prob, x)
71
72
new_prob = remake (prob, p = new_p, u0 = eltype (x).(prob. u0))
72
73
ts = sol_ref. t
73
74
new_sol = solve (new_prob, Rodas4 (), saveat = ts)
87
88
88
89
of = OptimizationFunction {true} (loss, AutoForwardDiff ())
89
90
90
- ps = (prob, sol_ref, get_vars, get_refs);
91
+ ps = (prob, sol_ref, get_vars, get_refs, set_x );
91
92
92
93
@test_call target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
93
94
@test_opt target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
94
95
95
96
@test all (.! isnan .(ForwardDiff. gradient (Base. Fix2 (of, ps), x0)))
96
97
97
- op = OptimizationProblem (of, x0, (prob, sol_ref, get_vars, get_refs) )
98
+ op = OptimizationProblem (of, x0, ps )
98
99
99
100
# using Plots
100
101
@@ -114,7 +115,7 @@ res = solve(op, Adam(), maxiters = 5000)#, callback = plot_cb)
114
115
115
116
@test res. objective < 1
116
117
117
- res_p = SciMLStructures . replace ( Tunable (), prob. p , res. u)
118
+ res_p = set_x ( prob, res. u)
118
119
res_prob = remake (prob, p = res_p)
119
120
res_sol = solve (res_prob, Rodas4 (), saveat = sol_ref. t)
120
121
0 commit comments