Commit 11e4825

change opt to state
1 parent 28091df commit 11e4825

4 files changed: +33 -30 lines changed


README.md

Lines changed: 2 additions & 2 deletions

@@ -25,9 +25,9 @@ data = [([x], 2x-x^3) for x in -2:0.1f0:2]
 
 model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)
 
-optim = Flux.setup(Adam(), model)
+state = Flux.setup(Adam(), model)
 for epoch in 1:1000
-    Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
+    Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, state)
 end
 
 plot(x -> 2x-x^3, -2, 2, legend=false)
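
For orientation, here is the renamed README snippet assembled into one runnable script. The `using Flux, Plots` line is an assumption (the imports sit outside this hunk); the `data` definition is taken from the hunk header above.

```julia
# Runnable sketch of the renamed README example.
# Assumption: `using Flux, Plots` (the imports are outside this hunk).
using Flux, Plots

data = [([x], 2x-x^3) for x in -2:0.1f0:2]       # from the hunk header context

model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)

state = Flux.setup(Adam(), model)                # optimiser state, hence the new name
for epoch in 1:1000
    Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, state)
end

plot(x -> 2x-x^3, -2, 2, legend=false)           # the target curve
```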

docs/src/models/quickstart.md

Lines changed: 5 additions & 5 deletions

@@ -27,7 +27,7 @@ target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneH
 loader = Flux.DataLoader((noisy, target) |> gpu, batchsize=64, shuffle=true);
 # 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)
 
-optim = Flux.setup(Flux.Adam(0.01), model)  # will store optimiser momentum, etc.
+state = Flux.setup(Flux.Adam(0.01), model)  # will store optimiser momentum, etc.
 
 # Training loop, using the whole data set 1000 times:
 losses = []
@@ -38,12 +38,12 @@ losses = []
             y_hat = m(x)
             Flux.crossentropy(y_hat, y)
         end
-        Flux.update!(optim, model, grads[1])
+        Flux.update!(state, model, grads[1])
         push!(losses, loss)  # logging, outside gradient context
     end
 end
 
-optim # parameters, momenta and output have all changed
+state # parameters, momenta and output have all changed
 out2 = model(noisy |> gpu) |> cpu  # first row is prob. of true, second row p(false)
 
 mean((out2[1,:] .> 0.5) .== truth)  # accuracy 94% so far!
@@ -95,7 +95,7 @@ Instead of calling [`gradient`](@ref Zygote.gradient) and [`update!`](@ref Flux.
 
 ```julia
 for epoch in 1:1_000
-    Flux.train!(model, loader, optim) do m, x, y
+    Flux.train!(model, loader, state) do m, x, y
         y_hat = m(x)
         Flux.crossentropy(y_hat, y)
     end
@@ -110,7 +110,7 @@ end
 ```
 (gradient of a zero-argument function) or
 ```
-train!((x,y) -> loss(model, x, y), Flux.params(model), loader, optim)
+train!((x,y) -> loss(model, x, y), Flux.params(model), loader, opt)
 ```
 (with `Flux.params`) is in the old "implicit" style.
 This still works on Flux 0.13, but will be removed from Flux 0.14.
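
The `Flux.update!(state, model, grads[1])` pattern these hunks rename can be seen in isolation in the sketch below. The toy model and batch are placeholders, not the quickstart's own; the `withgradient` call is assumed to wrap the loss exactly as in the surrounding unchanged lines.

```julia
# Minimal, self-contained sketch of the explicit gradient/update! pattern
# renamed here from `optim` to `state`. Model and data are placeholders.
using Flux

model = Chain(Dense(2 => 8, tanh), Dense(8 => 2), softmax)
x = rand(Float32, 2, 64)                               # one batch of inputs
y = Flux.onehotbatch(rand(Bool, 64), [true, false])    # matching one-hot targets

state = Flux.setup(Flux.Adam(0.01), model)             # stores optimiser momentum, etc.

loss, grads = Flux.withgradient(model) do m
    y_hat = m(x)
    Flux.crossentropy(y_hat, y)
end
Flux.update!(state, model, grads[1])                   # mutates model and state together
```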

docs/src/training/reference.md

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ see the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for det
 
 ```@docs
 Flux.Train.setup
-Flux.Train.train!(loss, model, data, opt; cb)
+Flux.Train.train!(loss, model, data, state; cb)
 Optimisers.update!
 ```
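
For context, a sketch of how the three documented calls fit together under the new naming. The tiny `Dense` model and random data are illustrative stand-ins; `Optimisers` is Flux's own dependency, loaded directly here only to show where `update!` comes from.

```julia
# Sketch of setup / gradient / update! under the renamed signature.
# The model and data are stand-ins for illustration only.
using Flux, Optimisers

model = Dense(3 => 1)
state = Flux.Train.setup(Optimisers.Descent(0.1), model)   # same as Flux.setup

x, y  = rand(Float32, 3), rand(Float32, 1)
grads = Flux.gradient(m -> Flux.mse(m(x), y), model)

Optimisers.update!(state, model, grads[1])                 # the update! documented above
```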

docs/src/training/training.md

Lines changed: 25 additions & 22 deletions

@@ -12,6 +12,9 @@ are handled one-by-one. One *epoch* of training means that each example is used
 something like this:
 
 ```julia
+# Initialise the optimiser for this model:
+state = Flux.setup(rule, model)
+
 for data in train_set
     # Unpack this element (for supervised training):
     input, label = data
@@ -24,16 +27,16 @@ for data in train_set
     end
 
     # Update the parameters so as to reduce the objective,
-    # according to a particular optimiser:
-    Flux.update!(opt, model, grads[1])
+    # according the chosen optimisation rule:
+    Flux.update!(state, model, grads[1])
 end
 ```
 
 This loop can also be written using the function [`train!`](@ref Flux.Train.train!),
 but it's helpful to undersand the pieces first:
 
 ```julia
-train!(model, train_set, opt) do m, x, y
+train!(model, train_set, state) do m, x, y
     loss(m(x), y)
 end
 ```
@@ -113,7 +116,7 @@ fmap(model, grads[1]) do p, g
 end
 ```
 
-A slightly more refined version of this loop to update all the parameters is wrapepd up as a function [`update!`](@ref Flux.Optimise.update!)`(opt, model, grads[1])`.
+A slightly more refined version of this loop to update all the parameters is wrapepd up as a function [`update!`](@ref Flux.Optimise.update!)`(state, model, grads[1])`.
 And the learning rate is the only thing stored in the [`Descent`](@ref Flux.Optimise.Descent) struct.
 
 However, there are many other optimisation rules, which adjust the step size and
@@ -126,13 +129,13 @@ first argument of `update!`. Like this:
 
 ```julia
 # Initialise momentum
-opt = Flux.setup(Momentum(0.01, 0.9), model)
+state = Flux.setup(Momentum(0.01, 0.9), model)
 
 for data in train_set
     grads = [...]
 
     # Update both model parameters and optimiser state:
-    Flux.update!(opt, model, grads[1])
+    Flux.update!(state, model, grads[1])
 end
 ```
 
@@ -192,17 +195,17 @@ Simple training loops like the one above can be written compactly using
 the [`train!`](@ref Flux.Train.train!) function. Including `setup`, this reads:
 
 ```julia
-opt = Flux.setup(Adam(), model)
+state = Flux.setup(Adam(), model)
 
 for epoch in 1:100
-    Flux.train!(model, train_set, opt) do m, x, y
+    Flux.train!(model, train_set, state) do m, x, y
         loss(m(x), y)
     end
 end
 ```
 
 Or explicitly writing the anonymous function which this `do` block creates,
-`train!((m,x,y) -> loss(m(x),y), model, train_set, opt)` is exactly equivalent.
+`train!((m,x,y) -> loss(m(x),y), model, train_set, state)` is exactly equivalent.
 
 !!! compat "Implicit-style `train!`"
     This is the new "explicit" method of `train!`, which takes the result of `setup` as its 4th argument.
@@ -224,7 +227,7 @@ callback API. Here is an example, in which it may be helpful to note:
 * Julia's `break` and `continue` keywords let you exit from parts of the loop.
 
 ```julia
-opt = Flux.setup(Adam(), model)
+state = Flux.setup(Adam(), model)
 
 my_log = []
 for epoch in 1:100
@@ -248,7 +251,7 @@ for epoch in 1:100
         continue
     end
 
-    Flux.update!(opt, model, grads[1])
+    Flux.update!(state, model, grads[1])
 end
 
 # Compute some accuracy, and save details as a NamedTuple
@@ -300,13 +303,13 @@ So there is a simpler way to implement exactly the same thing, by modifying the
 instead of the loss function. This is done by replacing this:
 
 ```julia
-opt = Flux.setup(Adam(0.1), model)
+state = Flux.setup(Adam(0.1), model)
 ```
 
 with this:
 
 ```julia
-decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model)
+decay_state = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model)
 ```
 
 Flux's optimisers are really modifications applied to the gradient before using it to update
@@ -328,12 +331,12 @@ Finer control of training, you may wish to alter the learning rate mid-way throu
 This can be done with [`adjust!`](@ref Flux.adjust!), like this:
 
 ```julia
-opt = Flux.setup(Adam(0.1), model)  # initialise once
+state = Flux.setup(Adam(0.1), model)  # initialise once
 
 for epoch in 1:1000
-    train!([...], opt)  # Train with η = 0.1 for first 100,
+    train!([...], state)  # Train with η = 0.1 for first 100,
     if epoch == 100     # then change to use η = 0.01 for the rest.
-        Flux.adjust!(opt, 0.01)
+        Flux.adjust!(state, 0.01)
     end
 end
 ```
@@ -342,7 +345,7 @@ end
 With the old "implicit" optimiser, `opt = Adam(0.1)`, the equivalent was to
 directly mutate the `Adam` struct, `opt.eta = 0.001`.
 
-Other hyper-parameters can also be adjusted, such as `Flux.adjust!(opt, beta = (0.8, 0.99))`.
+Other hyper-parameters can also be adjusted, such as `Flux.adjust!(state, beta = (0.8, 0.99))`.
 And such modifications can be applied to just one part of the model.
 For instance, this sets a different learning rate for the encoder and the decoder:
 
@@ -351,23 +354,23 @@ For instance, this sets a different learning rate for the encoder and the decode
 bimodel = Chain(enc = [...], dec = [...])
 
 # This returns a tree whose structure matches the model:
-opt = Flux.setup(Adam(0.02), bimodel)
+state = Flux.setup(Adam(0.02), bimodel)
 
 # Adjust the learning rate to be used for bimodel.layers.enc
-Flux.adjust!(opt.layers.enc, 0.03)
+Flux.adjust!(state.layers.enc, 0.03)
 ```
 
 To completely disable training of some part of the model, use [`freeze!`](@ref Flux.freeze!).
 This is a temporary modification, reversed by `thaw!`:
 
 ```julia
-Flux.freeze!(opt.layers.enc)
+Flux.freeze!(state.layers.enc)
 
 # Now training won't update parameters in bimodel.layers.enc
-train!(loss, bimodel, data, opt)
+train!(loss, bimodel, data, state)
 
 # Un-freeze the entire model:
-Flux.thaw!(opt)
+Flux.thaw!(state)
 ```
 
 !!! compat "Flux ≤ 0.13"
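
Taken together, the renamed `adjust!`, `freeze!` and `thaw!` calls in this file read as in the sketch below. The concrete encoder/decoder layers, loss and data are placeholders standing in for the docs' `bimodel = Chain(enc = [...], dec = [...])`.

```julia
# Sketch of the per-layer adjust!/freeze!/thaw! calls with the new `state` name.
# The concrete layers, loss and data are placeholders for the docs' `[...]`.
using Flux

bimodel = Chain(enc = Dense(4 => 8, relu), dec = Dense(8 => 4))
state   = Flux.setup(Adam(0.02), bimodel)        # state tree mirrors the model

Flux.adjust!(state.layers.enc, 0.03)             # encoder-only learning rate
Flux.adjust!(state, beta = (0.8, 0.99))          # other hyper-parameters, whole model

Flux.freeze!(state.layers.enc)                   # encoder parameters stop updating
data = [(rand(Float32, 4, 16), rand(Float32, 4, 16))]
Flux.train!((m, x, y) -> Flux.mse(m(x), y), bimodel, data, state)
Flux.thaw!(state)                                # re-enable training everywhere
```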
