Commit b015b7a

Remove train! from quickstart example (#2110)
* remove train from quickstart example
* fixes & suggestions
* better bullet points
* dump train! and gpu from the readme too
* remove a few comments
* rm mention of Zygote
* maybe we should have a much simpler readme example
* tweaks
* no more cbrt, no more abs2
* remove controversial println code, and make it shorter
* fix some fences
* maybe this example should run on the GPU, since it easily can, even though this is slower
* let's replace explicit printing with showprogress macro, it's pretty and doesn't waste lines
* add graph of the loss, since we log it? also move to a folder.
* one more .. perhaps
1 parent 065c191 commit b015b7a

File tree: 4 files changed, +66 -40 lines

README.md (12 additions & 12 deletions)

@@ -18,23 +18,23 @@
 
 Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.
 
-Works best with [Julia 1.8](https://julialang.org/downloads/) or later. Here's a simple example to try it out:
+Works best with [Julia 1.8](https://julialang.org/downloads/) or later. Here's a very short example to try it out:
 ```julia
-using Flux  # should install everything for you, including CUDA
+using Flux, Plots
+data = [([x], 2x-x^3) for x in -2:0.1f0:2]
 
-x = hcat(digits.(0:3, base=2, pad=2)...) |> gpu  # let's solve the XOR problem!
-y = Flux.onehotbatch(xor.(eachrow(x)...), 0:1) |> gpu
-data = ((Float32.(x), y) for _ in 1:100)  # an iterator making Tuples
+model = Chain(Dense(1 => 23, tanh), Dense(23 => 1, bias=false), only)
 
-model = Chain(Dense(2 => 3, sigmoid), BatchNorm(3), Dense(3 => 2)) |> gpu
-optim = Adam(0.1, (0.7, 0.95))
-mloss(x, y) = Flux.logitcrossentropy(model(x), y)  # closes over model
+mloss(x,y) = (model(x) - y)^2
+optim = Flux.Adam()
+for epoch in 1:1000
+  Flux.train!(mloss, Flux.params(model), data, optim)
+end
 
-Flux.train!(mloss, Flux.params(model), data, optim)  # updates model & optim
-
-all((softmax(model(x)) .> 0.5) .== y)  # usually 100% accuracy.
+plot(x -> 2x-x^3, -2, 2, legend=false)
+scatter!(-2:0.1:2, [model([x]) for x in -2:0.1:2])
 ```
 
-See the [documentation](https://fluxml.github.io/Flux.jl/) for details, or the [model zoo](https://github.com/FluxML/model-zoo/) for examples. Ask questions on the [Julia discourse](https://discourse.julialang.org/) or [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866).
+The [quickstart page](https://fluxml.ai/Flux.jl/stable/models/quickstart/) has a longer example. See the [documentation](https://fluxml.github.io/Flux.jl/) for details, or the [model zoo](https://github.com/FluxML/model-zoo/) for examples. Ask questions on the [Julia discourse](https://discourse.julialang.org/) or [slack](https://discourse.julialang.org/t/announcing-a-julia-slack/4866).
 
 If you use Flux in your research, please [cite](CITATION.bib) our work.
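As a quick check of the new README example (not part of this diff), the fit can also be inspected numerically rather than only by the plot. The following is a minimal sketch, assuming the `model`, `mloss`, and `data` defined in the snippet above:

```julia
# Hypothetical follow-up; uses only names from the new README example.
using Statistics

mean(mloss(x, y) for (x, y) in data)   # mean squared error over the training grid
model([1.5f0])                         # network prediction at x = 1.5
2 * 1.5f0 - 1.5f0^3                    # target value 2x - x^3 at the same point
```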

docs/src/assets/quickstart/loss.png (binary file, 61 KB)

docs/src/models/quickstart.md (54 additions & 28 deletions)

@@ -6,45 +6,54 @@ If you haven't, then you might prefer the [Fitting a Straight Line](overview.md)
 
 ```julia
 # With Julia 1.7+, this will prompt if necessary to install everything, including CUDA:
-using Flux, Statistics
+using Flux, Statistics, ProgressMeter
 
 # Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
 noisy = rand(Float32, 2, 1000)                                    # 2×1000 Matrix{Float32}
-truth = map(col -> xor(col...), eachcol(noisy .> 0.5))            # 1000-element Vector{Bool}
+truth = [xor(col[1]>0.5, col[2]>0.5) for col in eachcol(noisy)]   # 1000-element Vector{Bool}
 
 # Define our model, a multi-layer perceptron with one hidden layer of size 3:
-model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2), softmax)
+model = Chain(
+    Dense(2 => 3, tanh),   # activation function inside layer
+    BatchNorm(3),
+    Dense(3 => 2),
+    softmax) |> gpu        # move model to GPU, if available
 
 # The model encapsulates parameters, randomly initialised. Its initial output is:
-out1 = model(noisy)                                               # 2×1000 Matrix{Float32}
+out1 = model(noisy |> gpu) |> cpu                                 # 2×1000 Matrix{Float32}
 
-# To train the model, we use batches of 64 samples:
-mat = Flux.onehotbatch(truth, [true, false])                      # 2×1000 OneHotMatrix
-data = Flux.DataLoader((noisy, mat), batchsize=64, shuffle=true);
-first(data) .|> summary                                           # ("2×64 Matrix{Float32}", "2×64 Matrix{Bool}")
+# To train the model, we use batches of 64 samples, and one-hot encoding:
+target = Flux.onehotbatch(truth, [true, false])                   # 2×1000 OneHotMatrix
+loader = Flux.DataLoader((noisy, target) |> gpu, batchsize=64, shuffle=true);
+# 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)
 
 pars = Flux.params(model)  # contains references to arrays in model
 opt = Flux.Adam(0.01)      # will store optimiser momentum, etc.
 
 # Training loop, using the whole data set 1000 times:
-for epoch in 1:1_000
-    Flux.train!(pars, data, opt) do x, y
-        # First argument of train! is a loss function, here defined by a `do` block.
-        # This gets x and y, each a 2×64 Matrix, from data, and compares:
-        Flux.crossentropy(model(x), y)
+losses = []
+@showprogress for epoch in 1:1_000
+    for (x, y) in loader
+        loss, grad = Flux.withgradient(pars) do
+            # Evaluate model and loss inside gradient context:
+            y_hat = model(x)
+            Flux.crossentropy(y_hat, y)
+        end
+        Flux.update!(opt, pars, grad)
+        push!(losses, loss)  # logging, outside gradient context
     end
 end
 
-pars  # has changed!
+pars  # parameters, momenta and output have all changed
 opt
-out2 = model(noisy)
+out2 = model(noisy |> gpu) |> cpu  # first row is prob. of true, second row p(false)
 
 mean((out2[1,:] .> 0.5) .== truth)  # accuracy 94% so far!
 ```
 
-![](../assets/oneminute.png)
+![](../assets/quickstart/oneminute.png)
 
-```
+```julia
 using Plots  # to draw the above figure
 
 p_true = scatter(noisy[1,:], noisy[2,:], zcolor=truth, title="True classification", legend=false)
@@ -54,26 +63,43 @@ p_done = scatter(noisy[1,:], noisy[2,:], zcolor=out2[1,:], title="Trained networ
 plot(p_true, p_raw, p_done, layout=(1,3), size=(1000,330))
 ```
 
+```@raw html
+<img align="right" width="300px" src="../../assets/quickstart/loss.png">
+```
+
+Here's the loss during training:
+
+```julia
+plot(losses; xaxis=(:log10, "iteration"),
+    yaxis="loss", label="per batch")
+n = length(loader)
+plot!(n:n:length(losses), mean.(Iterators.partition(losses, n)),
+    label="epoch mean", dpi=200)
+```
+
 This XOR ("exclusive or") problem is a variant of the famous one which drove Minsky and Papert to invent deep neural networks in 1969. For small values of "deep" -- this has one hidden layer, while earlier perceptrons had none. (What they call a hidden layer, Flux calls the output of the first layer, `model[1](noisy)`.)
 
 Since then things have developed a little.
 
-## Features of Note
+## Features to Note
 
 Some things to notice in this example are:
 
-* The batch dimension of data is always the last one. Thus a `2×1000 Matrix` is a thousand observations, each a column of length 2.
-
-* The `model` can be called like a function, `y = model(x)`. It encapsulates the parameters (and state).
+* The batch dimension of data is always the last one. Thus a `2×1000 Matrix` is a thousand observations, each a column of length 2. Flux defaults to `Float32`, but most of Julia to `Float64`.
 
-* But the model does not contain the loss function, nor the optimisation rule. Instead the [`Adam()`](@ref Flux.Adam) object stores between iterations the momenta it needs.
+* The `model` can be called like a function, `y = model(x)`. Each layer like [`Dense`](@ref Flux.Dense) is an ordinary `struct`, which encapsulates some arrays of parameters (and possibly other state, as for [`BatchNorm`](@ref Flux.BatchNorm)).
 
-* The function [`train!`](@ref Flux.train!) likes data as an iterator generating `Tuple`s, here produced by [`DataLoader`](@ref). This mutates both the `model` and the optimiser state inside `opt`.
+* But the model does not contain the loss function, nor the optimisation rule. The [`Adam`](@ref Flux.Adam) object stores between iterations the momenta it needs. And [`Flux.crossentropy`](@ref Flux.Losses.crossentropy) is an ordinary function.
 
-There are other ways to train Flux models, for more control than `train!` provides:
+* The `do` block creates an anonymous function, as the first argument of `gradient`. Anything executed within this is differentiated.
 
-* Within Flux, you can easily write a training loop, calling [`gradient`](@ref) and [`update!`](@ref Flux.update!).
+Instead of calling [`gradient`](@ref Zygote.gradient) and [`update!`](@ref Flux.update!) separately, there is a convenience function [`train!`](@ref Flux.train!). If we didn't want anything extra (like logging the loss), we could replace the training loop with the following:
 
-* For a lower-level way, see the package [Optimisers.jl](https://github.com/FluxML/Optimisers.jl).
-
-* For higher-level ways, see [FluxTraining.jl](https://github.com/FluxML/FluxTraining.jl) and [FastAI.jl](https://github.com/FluxML/FastAI.jl).
+```julia
+for epoch in 1:1_000
+    train!(pars, loader, opt) do x, y
+        y_hat = model(x)
+        Flux.crossentropy(y_hat, y)
+    end
+end
+```
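For readers trying the revised quickstart interactively, the points in the new "Features to Note" bullets can be checked directly at the REPL. This is a minimal sketch, not part of the diff above, assuming the `model`, `noisy`, and `truth` defined in the new version of the example (without a working GPU, the `gpu`/`cpu` moves simply pass data through unchanged):

```julia
# Hypothetical REPL session; uses only names from the new quickstart code.
x1 = noisy[:, 1:5]                 # batch dimension is last: five observations, each a column
y1 = model(x1 |> gpu) |> cpu       # the model is callable like a function; output is 2×5
Flux.onecold(y1, [true, false])    # decode the softmax output back into Bool labels
model[1].weight                    # each layer is a struct carrying its parameter arrays
```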
