@@ -14,7 +14,7 @@ using ComponentArrays, DiffEqFlux, DifferentialEquations, Optimization,
14
14
nn = Chain(Dense(1, 3, tanh), Dense(3, 1, tanh))
15
15
tspan = (0.0f0, 10.0f0)
16
16
17
- ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5())
17
+ ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5(); ad = AutoZygote() )
18
18
ps, st = Lux.setup(Random.default_rng(), ffjord_mdl)
19
19
ps = ComponentArray(ps)
20
20
model = Lux.Experimental.StatefulLuxLayer(ffjord_mdl, nothing, st)
@@ -33,7 +33,7 @@ function cb(p, l)
33
33
return false
34
34
end
35
35
36
- adtype = Optimization.AutoZygote ()
36
+ adtype = Optimization.AutoForwardDiff ()
37
37
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
38
38
optprob = Optimization.OptimizationProblem(optf, ps)
39
39
@@ -68,7 +68,7 @@ using ComponentArray, DiffEqFlux, DifferentialEquations, Optimization,
68
68
nn = Chain(Dense(1, 3, tanh), Dense(3, 1, tanh))
69
69
tspan = (0.0f0, 10.0f0)
70
70
71
- ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5())
71
+ ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5(); ad = AutoZygote() )
72
72
ps, st = Lux.setup(Random.default_rng(), ffjord_mdl)
73
73
ps = ComponentArray(ps)
74
74
model = Lux.Experimental.StatefulLuxLayer(ps, st, ffjord_mdl)
@@ -109,7 +109,7 @@ Here we showcase starting the optimization with `Adam` to more quickly find a mi
109
109
110
110
``` @example cnf2
111
111
112
- adtype = Optimization.AutoZygote ()
112
+ adtype = Optimization.AutoForwardDiff ()
113
113
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
114
114
optprob = Optimization.OptimizationProblem(optf, ps)
115
115
0 commit comments