
Commit 40f0a63

Re-write training docs (#2114)
* re-write training.md
* add train_api page for docstrings
* update basic.md to introduce explicit not implicit
* more links, comments on notes
* updates, rm some Optimisers detail
* mention TerminalLoggers
* tweaks
* perhaps we should build regularisation into the same page
* tweaks
* update quickstart + readme too
* finish freezing etc, update everything
* fix a test, etc
* add note to "advanced" page
* tweaks
* comments
* tweaks, bugs, missing files, etc
* move a sentence
* change opt to state
* new page lost in rebase
* don't say "explicit" so often
* opt to state in a few more places
* add three compat boxes about common errors / problems re old versions
* change to opt_state
* fixes
* fixup
* fixup
* fixup
* spelling & indentation
1 parent 4f015e9 commit 40f0a63

21 files changed: +639 -355 lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
````diff
@@ -33,7 +33,7 @@ MacroTools = "0.5"
 NNlib = "0.8.9"
 NNlibCUDA = "0.2.4"
 OneHotArrays = "0.1, 0.2"
-Optimisers = "0.2.10"
+Optimisers = "0.2.12"
 ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 SpecialFunctions = "1.8.2, 2.1.2"
````

README.md

Lines changed: 2 additions & 3 deletions
````diff
@@ -25,10 +25,9 @@ data = [([x], 2x-x^3) for x in -2:0.1f0:2]
 
 model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)
 
-mloss(x,y) = (model(x) - y)^2
-optim = Flux.Adam()
+optim = Flux.setup(Adam(), model)
 for epoch in 1:1000
-    Flux.train!(mloss, Flux.params(model), data, optim)
+    Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
 end
 
 plot(x -> 2x-x^3, -2, 2, legend=false)
````
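
For context, this is roughly how the full README example reads after the change, assembled from the hunk above. The `using` line and the final scatter call are assumptions about the surrounding README, not part of this diff:

```julia
using Flux, Plots

data = [([x], 2x-x^3) for x in -2:0.1f0:2]   # a vector of (input, target) tuples

model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)

optim = Flux.setup(Adam(), model)            # optimiser state tied to the model's parameters
for epoch in 1:1000
    Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)  # loss now takes the model explicitly
end

plot(x -> 2x-x^3, -2, 2, legend=false)
scatter!(-2:0.1f0:2, [model([x]) for x in -2:0.1f0:2])  # assumed visualisation line, not shown in the hunk
```

The key change is that `Flux.setup` pairs the optimisation rule with the model once, and `train!` then receives the model itself rather than `Flux.params(model)`.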

docs/Project.toml

Lines changed: 1 addition & 0 deletions
````diff
@@ -3,6 +3,7 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
````

docs/make.jl

Lines changed: 2 additions & 2 deletions
````diff
@@ -18,7 +18,6 @@ makedocs(
         "Fitting a Line" => "models/overview.md",
         "Gradients and Layers" => "models/basics.md",
         "Training" => "training/training.md",
-        "Regularisation" => "models/regularisation.md",  # consolidated in #2114
         "Recurrence" => "models/recurrence.md",
         "GPU Support" => "gpu.md",
         "Saving & Loading" => "saving.md",
@@ -31,7 +30,8 @@ makedocs(
         "Activation Functions" => "models/activation.md",
         "Weight Initialisation" => "utilities.md",
         "Loss Functions" => "models/losses.md",
-        "Optimisation Rules" => "training/optimisers.md",  # TODO move optimiser intro up to Training
+        "Training API" => "training/reference.md",
+        "Optimisation Rules" => "training/optimisers.md",
         "Shape Inference" => "outputsize.md",
         "Flat vs. Nested" => "destructure.md",
         "Callback Helpers" => "training/callbacks.md",
````

docs/src/destructure.md

Lines changed: 6 additions & 0 deletions
````diff
@@ -49,6 +49,12 @@ julia> Flux.destructure(grad) # acts on non-models, too
 (Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184], Restructure(Tuple, ..., 5))
 ```
 
+!!! compat "Flux ≤ 0.12"
+    Old versions of Flux had an entirely different implementation of `destructure`, which
+    had many bugs (and almost no tests). Many comments online still refer to that now-deleted
+    function, or to memories of it.
+
+
 ### All Parameters
 
 The function `destructure` now lives in [`Optimisers.jl`](https://github.com/FluxML/Optimisers.jl).
````
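
For readers landing on the compat box above, a minimal sketch of the current `destructure` API (the model here is made up for illustration):

```julia
using Flux

model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))

# destructure returns a flat vector of every parameter, plus a function
# which rebuilds a model of the same shape from such a vector:
flat, re = Flux.destructure(model)
length(flat)             # 13 = (3×2 weight + 3 bias) + (1×3 weight + 1 bias)

model2 = re(2 .* flat)   # a new Chain whose parameters are all doubled
```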

docs/src/models/advanced.md

Lines changed: 12 additions & 0 deletions
````diff
@@ -69,6 +69,10 @@ However, doing this requires the `struct` to have a corresponding constructor th
 
 When it is desired to not include all the model parameters (for e.g. transfer learning), we can simply not pass in those layers into our call to `params`.
 
+!!! compat "Flux ≤ 0.13"
+    The mechanism described here is for Flux's old "implicit" training style.
+    When upgrading for Flux 0.14, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`.
+
 Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain
 this using the slicing features `Chain` provides:
 
@@ -155,6 +159,10 @@ model(xs)
 # returns a single float vector with one value
 ```
 
+!!! note
+    This `Join` layer is available from the [Fluxperimental.jl](https://github.com/FluxML/Fluxperimental.jl) package.
+
+
 #### Using `Parallel`
 
 Flux already provides [`Parallel`](@ref) that can offer the same functionality. In this case, `Join` is going to just be syntactic sugar for `Parallel`.
@@ -223,3 +231,7 @@ function loss(x, ys, model)
     return sqrt(mean(Flux.mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs)))
 end
 ```
+
+!!! note
+    This `Split` layer is available from the [Fluxperimental.jl](https://github.com/FluxML/Fluxperimental.jl) package.
+
````
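
As a companion to the first compat box added above, here is a rough sketch of the explicit-style replacement it points to. The `Chain`, the learning rate, and the use of `thaw!` as the counterpart of `freeze!` are illustrative assumptions, not part of this diff:

```julia
using Flux

model = Chain(Dense(2 => 3, tanh), Dense(3 => 2), softmax)
opt_state = Flux.setup(Adam(0.01), model)

# Mark the first layer's optimiser state as frozen, so that update!/train!
# leave the corresponding parameters unchanged:
Flux.freeze!(opt_state.layers[1])

# ... train with train! or update! as usual ...

# Later, make that layer trainable again (assumed counterpart named in the note):
Flux.thaw!(opt_state.layers[1])
```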

docs/src/models/basics.md

Lines changed: 61 additions & 19 deletions
````diff
@@ -1,6 +1,6 @@
 # [How Flux Works: Gradients and Layers](@id man-basics)
 
-## Taking Gradients
+## [Taking Gradients](@id man-taking-gradients)
 
 Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
 
@@ -29,35 +29,77 @@ julia> gradient(f, [2, 1], [2, 0])
 ([0.0, 2.0], [-0.0, -2.0])
 ```
 
-These gradients are based on `x` and `y`. Flux works by instead taking gradients based on the weights and biases that make up the parameters of a model.
+These gradients are based on `x` and `y`. Flux works by instead taking gradients based on the weights and biases that make up the parameters of a model.
 
-
-Machine learning often can have *hundreds* of parameters, so Flux lets you work with collections of parameters, via the `params` functions. You can get the gradient of all parameters used in a program without explicitly passing them in.
+Machine learning often can have *hundreds* of parameter arrays.
+Instead of passing them to `gradient` individually, we can store them together in a structure.
+The simplest example is a named tuple, created by the following syntax:
 
 ```jldoctest basics
-julia> x = [2, 1];
+julia> nt = (a = [2, 1], b = [2, 0], c = tanh);
+
+julia> g(x::NamedTuple) = sum(abs2, x.a .- x.b);
+
+julia> g(nt)
+1
+
+julia> dg_nt = gradient(g, nt)[1]
+(a = [0.0, 2.0], b = [-0.0, -2.0], c = nothing)
+```
+
+Notice that `gradient` has returned a matching structure. The field `dg_nt.a` is the gradient
+for `nt.a`, and so on. Some fields have no gradient, indicated by `nothing`.
 
-julia> y = [2, 0];
+Rather than define a function like `g` every time (and think up a name for it),
+it is often useful to use anonymous functions: this one is `x -> sum(abs2, x.a .- x.b)`.
+Anonymous functions can be defined either with `->` or with `do`,
+and such `do` blocks are often useful if you have a few steps to perform:
+
+```jldoctest basics
+julia> gradient((x, y) -> sum(abs2, x.a ./ y .- x.b), nt, [1, 2])
+((a = [0.0, 0.5], b = [-0.0, -1.0], c = nothing), [-0.0, -0.25])
 
-julia> gs = gradient(Flux.params(x, y)) do
-         f(x, y)
+julia> gradient(nt, [1, 2]) do x, y
+         z = x.a ./ y
+         sum(abs2, z .- x.b)
        end
-Grads(...)
+((a = [0.0, 0.5], b = [-0.0, -1.0], c = nothing), [-0.0, -0.25])
+```
 
-julia> gs[x]
-2-element Vector{Float64}:
- 0.0
- 2.0
+Sometimes you may want to know the value of the function, as well as its gradient.
+Rather than calling the function a second time, you can call [`withgradient`](@ref Zygote.withgradient) instead:
 
-julia> gs[y]
-2-element Vector{Float64}:
- -0.0
- -2.0
 ```
+julia> Flux.withgradient(g, nt)
+(val = 1, grad = ((a = [0.0, 2.0], b = [-0.0, -2.0], c = nothing),))
+```
+
+!!! note "Implicit gradients"
+    Flux used to handle many parameters in a different way, using the [`params`](@ref Flux.params) function.
+    This uses a method of `gradient` which takes a zero-argument function, and returns a dictionary
+    through which the resulting gradients can be looked up:
+
+    ```jldoctest basics
+    julia> x = [2, 1];
+
+    julia> y = [2, 0];
+
+    julia> gs = gradient(Flux.params(x, y)) do
+             f(x, y)
+           end
+    Grads(...)
+
+    julia> gs[x]
+    2-element Vector{Float64}:
+     0.0
+     2.0
 
-Here, `gradient` takes a zero-argument function; no arguments are necessary because the `params` tell it what to differentiate.
+    julia> gs[y]
+    2-element Vector{Float64}:
+     -0.0
+     -2.0
+    ```
 
-This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
 
 ## Building Simple Models
 
````

docs/src/models/layers.md

Lines changed: 6 additions & 1 deletion
````diff
@@ -15,7 +15,7 @@ The `Dense` exemplifies several features:
 * It is annotated with [`@functor`](@ref Functors.@functor), which means that [`params`](@ref Flux.params) will see the contents, and [`gpu`](@ref Flux.gpu) will move their arrays to the GPU.
 
 By contrast, `Chain` itself contains no parameters, but connects other layers together.
-The section on [dataflow layers](@ref man-dataflow-layers) introduces others like this,
+The section on [dataflow layers](@ref man-dataflow-layers) introduces others like this.
 
 ## Fully Connected
 
@@ -27,6 +27,11 @@ Flux.Scale
 
 Perhaps `Scale` isn't quite fully connected, but it may be thought of as `Dense(Diagonal(s.weights), s.bias)`, and LinearAlgebra's `Diagonal` is a matrix which just happens to contain many zeros.
 
+!!! compat "Flux ≤ 0.12"
+    Old versions of Flux accepted only `Dense(in, out, act)` and not `Dense(in => out, act)`.
+    This notation makes a `Pair` object. If you get an error like `MethodError: no method matching Dense(::Pair{Int64,Int64})`, this means that you should upgrade to Flux 0.13.
+
+
 ## Convolution Models
 
 These layers are used to build convolutional neural networks (CNNs).
````
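
To make the new compat box concrete, this is the `Pair` notation it refers to (the layer sizes are arbitrary):

```julia
using Flux

layer = Dense(4 => 2, relu)     # Flux ≥ 0.13: input => output written as a Pair
layer(rand(Float32, 4, 8))      # 2×8 output for a batch of 8 samples

# The older spelling of the same layer was Dense(4, 2, relu).
```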

docs/src/models/quickstart.md

Lines changed: 22 additions & 10 deletions
````diff
@@ -27,25 +27,23 @@ target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneH
 loader = Flux.DataLoader((noisy, target) |> gpu, batchsize=64, shuffle=true);
 # 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)
 
-pars = Flux.params(model)  # contains references to arrays in model
-opt = Flux.Adam(0.01)  # will store optimiser momentum, etc.
+optim = Flux.setup(Flux.Adam(0.01), model)  # will store optimiser momentum, etc.
 
 # Training loop, using the whole data set 1000 times:
 losses = []
 @showprogress for epoch in 1:1_000
     for (x, y) in loader
-        loss, grad = Flux.withgradient(pars) do
+        loss, grads = Flux.withgradient(model) do m
             # Evaluate model and loss inside gradient context:
-            y_hat = model(x)
+            y_hat = m(x)
             Flux.crossentropy(y_hat, y)
         end
-        Flux.update!(opt, pars, grad)
+        Flux.update!(optim, model, grads[1])
         push!(losses, loss)  # logging, outside gradient context
     end
 end
 
-pars  # parameters, momenta and output have all changed
-opt
+optim  # parameters, momenta and output have all changed
 out2 = model(noisy |> gpu) |> cpu  # first row is prob. of true, second row p(false)
 
 mean((out2[1,:] .> 0.5) .== truth)  # accuracy 94% so far!
@@ -89,17 +87,31 @@ Some things to notice in this example are:
 
 * The `model` can be called like a function, `y = model(x)`. Each layer like [`Dense`](@ref Flux.Dense) is an ordinary `struct`, which encapsulates some arrays of parameters (and possibly other state, as for [`BatchNorm`](@ref Flux.BatchNorm)).
 
-* But the model does not contain the loss function, nor the optimisation rule. The [`Adam`](@ref Flux.Adam) object stores between iterations the momenta it needs. And [`Flux.crossentropy`](@ref Flux.Losses.crossentropy) is an ordinary function.
+* But the model does not contain the loss function, nor the optimisation rule. The momenta needed by [`Adam`](@ref Flux.Adam) are stored in the object returned by [setup](@ref Flux.Train.setup). And [`Flux.crossentropy`](@ref Flux.Losses.crossentropy) is an ordinary function.
 
 * The `do` block creates an anonymous function, as the first argument of `gradient`. Anything executed within this is differentiated.
 
 Instead of calling [`gradient`](@ref Zygote.gradient) and [`update!`](@ref Flux.update!) separately, there is a convenience function [`train!`](@ref Flux.train!). If we didn't want anything extra (like logging the loss), we could replace the training loop with the following:
 
 ```julia
 for epoch in 1:1_000
-    Flux.train!(pars, loader, opt) do x, y
-        y_hat = model(x)
+    Flux.train!(model, loader, optim) do m, x, y
+        y_hat = m(x)
         Flux.crossentropy(y_hat, y)
     end
 end
 ```
+
+!!! compat "Implicit-style training, Flux ≤ 0.13"
+    Until recently Flux's training worked a bit differently.
+    Any code which looks like
+    ```
+    gradient(() -> loss(model, x, y), Flux.params(model))
+    ```
+    (gradient of a zero-argument function) or
+    ```
+    train!((x,y) -> loss(model, x, y), Flux.params(model), loader, opt)
+    ```
+    (with `Flux.params`) is in the old "implicit" style.
+    This still works on Flux 0.13, but will be removed from Flux 0.14.
+    See the [training section](@ref man-training) for more details.
````
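
For comparison with the compat box just added, a sketch of the explicit-style equivalents of those two implicit snippets. Here `loss`, `loader`, `x`, `y` and `model` are the same hypothetical objects as in the box, and `opt_state` comes from `Flux.setup`:

```julia
opt_state = Flux.setup(Adam(), model)             # replaces Flux.params(model) plus a bare optimiser

grads = Flux.gradient(m -> loss(m, x, y), model)  # gradient with respect to the model itself
Flux.update!(opt_state, model, grads[1])

# or, folded into a single call over the whole data set:
Flux.train!((m, x, y) -> loss(m, x, y), model, loader, opt_state)
```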

docs/src/models/regularisation.md

Lines changed: 0 additions & 80 deletions
This file was deleted.
