docs/src/training/training.md (+25 −22 lines)
@@ -12,6 +12,9 @@ are handled one-by-one. One *epoch* of training means that each example is used
something like this:

```julia
# Initialise the optimiser for this model:
state = Flux.setup(rule, model)

for data in train_set
  # Unpack this element (for supervised training):
  input, label = data
@@ -24,16 +27,16 @@ for data in train_set
  end

  # Update the parameters so as to reduce the objective,
  # according to the chosen optimisation rule:
  Flux.update!(state, model, grads[1])
end
```

This loop can also be written using the function [`train!`](@ref Flux.Train.train!),
but it's helpful to understand the pieces first:

```julia
train!(model, train_set, state) do m, x, y
  loss(m(x), y)
end
```
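Put together, these pieces make a complete runnable loop. Here is a minimal sketch, in which the toy model, dataset, learning rate, and `mse` loss are illustrative assumptions, not part of the page above:

```julia
using Flux

# Illustrative toy model and dataset (assumptions, not from the docs):
model = Dense(2 => 1)
train_set = [(randn(Float32, 2), randn(Float32, 1)) for _ in 1:8]

# Initialise the optimiser state for this model:
state = Flux.setup(Descent(0.01), model)

for data in train_set
  # Unpack this element (for supervised training):
  input, label = data

  # Gradient of the loss with respect to the model:
  grads = Flux.gradient(m -> Flux.mse(m(input), label), model)

  # Update parameters and optimiser state together:
  Flux.update!(state, model, grads[1])
end
```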
@@ -113,7 +116,7 @@ fmap(model, grads[1]) do p, g
end
```

A slightly more refined version of this loop to update all the parameters is wrapped up as a function [`update!`](@ref Flux.Optimise.update!)`(state, model, grads[1])`.
And the learning rate is the only thing stored in the [`Descent`](@ref Flux.Optimise.Descent) struct.

However, there are many other optimisation rules, which adjust the step size and
@@ -126,13 +129,13 @@ first argument of `update!`. Like this:

```julia
# Initialise momentum
state = Flux.setup(Momentum(0.01, 0.9), model)

for data in train_set
  grads = [...]

  # Update both model parameters and optimiser state:
  Flux.update!(state, model, grads[1])
end
```
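As a hedged aside on what `setup` produces: it returns a tree of optimiser state mirroring the model, so momentum's velocity lives next to the parameter it belongs to. The `Dense` model below is an assumption, for illustration only:

```julia
using Flux

model = Dense(2 => 1)  # hypothetical model, for illustration
state = Flux.setup(Momentum(0.01, 0.9), model)

# `state` mirrors the model's structure: one leaf of optimiser state
# (the rule plus its velocity array) for `model.weight`, and one for
# `model.bias`. Each call to `update!` mutates these velocities.
state.weight
```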
@@ -192,17 +195,17 @@ Simple training loops like the one above can be written compactly using
192
195
the [`train!`](@ref Flux.Train.train!) function. Including `setup`, this reads:
193
196
194
197
```julia
195
-
opt= Flux.setup(Adam(), model)
198
+
state= Flux.setup(Adam(), model)
196
199
197
200
for epoch in1:100
198
-
Flux.train!(model, train_set, opt) do m, x, y
201
+
Flux.train!(model, train_set, state) do m, x, y
199
202
loss(m(x), y)
200
203
end
201
204
end
202
205
```
203
206
204
207
Or explicitly writing the anonymous function which this `do` block creates,
205
-
`train!((m,x,y) -> loss(m(x),y), model, train_set, opt)` is exactly equivalent.
208
+
`train!((m,x,y) -> loss(m(x),y), model, train_set, state)` is exactly equivalent.
206
209
207
210
!!! compat "Implicit-style `train!`"
208
211
This is the new "explicit" method of `train!`, which takes the result of `setup` as its 4th argument.
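A minimal runnable version of this `setup`-plus-`train!` pattern might look like the following, where the toy model, dataset, and `mse` loss are assumptions for illustration:

```julia
using Flux

model = Dense(2 => 1)  # illustrative model
train_set = [(randn(Float32, 2), randn(Float32, 1)) for _ in 1:8]
state = Flux.setup(Adam(), model)

for epoch in 1:100
  Flux.train!(model, train_set, state) do m, x, y
    Flux.mse(m(x), y)  # any differentiable loss of (prediction, target)
  end
end
```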
@@ -224,7 +227,7 @@ callback API. Here is an example, in which it may be helpful to note:
* Julia's `break` and `continue` keywords let you exit from parts of the loop.

```julia
state = Flux.setup(Adam(), model)

my_log = []
for epoch in 1:100
@@ -248,7 +251,7 @@ for epoch in 1:100
      continue
    end

    Flux.update!(state, model, grads[1])
  end

  # Compute some accuracy, and save details as a NamedTuple
@@ -300,13 +303,13 @@ So there is a simpler way to implement exactly the same thing, by modifying the
instead of the loss function. This is done by replacing this: