Commit c6bac9a
perhaps we should build regularisation into the same page
1 parent 3d7eb3f commit c6bac9a

File tree

3 files changed: +105 −1 lines changed

docs/make.jl

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ makedocs(
         "Training Models" => [
             "Training" => "training/training.md",
             "Training API 📚" => "training/train_api.md",
-            "Regularisation" => "models/regularisation.md",
+            # "Regularisation" => "models/regularisation.md",
             "Loss Functions 📚" => "models/losses.md",
             "Optimisation Rules 📚" => "training/optimisers.md", # TODO move optimiser intro up to Training
             "Callback Helpers 📚" => "training/callbacks.md",

docs/src/training/training.md

Lines changed: 63 additions & 0 deletions
@@ -275,3 +275,66 @@ For more details on training in the implicit style, see [Flux 0.13.6 documentati

For details about the two gradient modes, see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1).

## Regularisation

The term *regularisation* covers a wide variety of techniques aiming to improve the
result of training. This is often done to avoid overfitting.

Some of these can be implemented by simply modifying the loss function.
L2 regularisation (sometimes called ridge regression, or weight decay) adds to the loss
a penalty proportional to `θ^2` for every scalar parameter `θ`,
and for a simple model could be implemented as follows:

```julia
Flux.gradient(model) do m
    result = m(input)
    penalty = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
    my_loss(result, label) + 0.42 * penalty
end
```

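The useful property of this particular penalty is that its gradient with respect to each parameter is simply the parameter itself. A standalone check in plain Julia (no Flux required), using finite differences on a made-up parameter vector:

```julia
# Sketch: the gradient of θ -> sum(abs2, θ)/2 is θ itself.
# Checked numerically with central finite differences; θ is made up.
pen_l2(x) = sum(abs2, x) / 2

θ = [0.5, -1.25, 3.0]
ϵ = 1e-6
basis(i) = [j == i ? 1.0 : 0.0 for j in 1:3]   # i-th unit vector
grad = [(pen_l2(θ .+ ϵ .* basis(i)) - pen_l2(θ .- ϵ .* basis(i))) / (2ϵ) for i in 1:3]
# grad ≈ θ, i.e. approximately [0.5, -1.25, 3.0]
```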
Accessing each individual parameter array by hand won't work well for large models.
Instead, we can use [`Flux.params`](@ref) to collect all of them,
then apply a function to each one and sum the results:

```julia
pen_l2(x::AbstractArray) = sum(abs2, x)/2

Flux.gradient(model) do m
    result = m(input)
    penalty = sum(pen_l2, Flux.params(m))
    my_loss(result, label) + 0.42 * penalty
end
```

However, the gradient of this penalty term is very simple: it is proportional to the original weights.
So there is a simpler way to implement exactly the same thing, by modifying the optimiser
instead of the loss function. This is done by replacing this:

```julia
opt = Flux.setup(Adam(0.1), model)
```

with this:

```julia
decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model)
```

Flux's optimisers are really modifications applied to the gradient before using it to update
the parameters, and `OptimiserChain` applies two such modifications.
The first, [`WeightDecay`](@ref), adds `0.42` times the original parameter to the gradient,
matching the gradient of the penalty above (with the same, unrealistically large, constant).
After that, in either case, [`Adam`](@ref) computes the final update.

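To see why the two routes agree, here is a minimal standalone sketch in plain Julia (no Flux), using a plain gradient-descent step for simplicity; `λ`, `η` and `grad_loss` are made-up numbers. Since both routes hand the optimiser the same gradient, the equivalence carries over to `Adam` as well:

```julia
# Sketch: λ * sum(abs2, θ)/2 added to the loss contributes λ .* θ to the
# gradient, which is exactly what WeightDecay(λ) adds before the update.
λ, η = 0.42, 0.1
θ = [1.0, -2.0, 0.5]
grad_loss = [0.3, -0.1, 0.2]      # pretend gradient of the unpenalised loss

# Route 1: penalty folded into the loss, so its gradient λ .* θ appears there.
θ1 = θ .- η .* (grad_loss .+ λ .* θ)

# Route 2: weight decay applied as a gradient transformation before the step.
decayed = grad_loss .+ λ .* θ
θ2 = θ .- η .* decayed

θ1 == θ2    # true: the two routes produce identical parameters
```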
The same mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref).

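As a rough standalone sketch of what such clipping does (plain Julia, not the Flux API; the vector `g` is made up): each gradient entry is clamped to a fixed interval before the rest of the chain sees it, which bounds the size of any single update.

```julia
# Sketch of element-wise gradient clipping to [-δ, δ].
clip(g, δ) = clamp.(g, -δ, δ)

g = [0.1, -5.0, 2.7]
clip(g, 1.0)    # [0.1, -1.0, 1.0]
```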
Besides L2 / weight decay, another common and quite different kind of regularisation is
provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some fraction of the
outputs of the previous layer at random, on each training step.

Dropout should be active only during training; Flux's layers handle this automatically,
and [`Flux.testmode!`](@ref) can be used to control it by hand.

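A minimal standalone sketch of the idea (plain Julia, not Flux's implementation): each entry is zeroed with probability `p`, and the survivors are rescaled by `1/(1-p)` so that the expected value of the output matches the input.

```julia
# Sketch of "inverted" dropout as applied during training.
function dropout_sketch(x, p)
    mask = rand(length(x)) .> p      # true = keep this entry
    x .* mask ./ (1 - p)             # rescale survivors by 1/(1-p)
end

x = ones(6)
y = dropout_sketch(x, 0.5)
# Each entry of y is either 0.0 or 2.0; at test time no mask is applied at all.
```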
## Freezing, Schedules

Freezing some parameters during training, and schedules which adjust the learning
rate over time, are other techniques which may also fit under this heading.

test/train.jl

Lines changed: 41 additions & 0 deletions
@@ -91,3 +91,44 @@ end
    @test y5 < y4
end

@testset "L2 regularisation" begin
    # New docs claim an exact equivalent. It's a bit long to put the example in there,
    # but perhaps the tests should contain it.

    model = Dense(3 => 2, tanh);
    init_weight = copy(model.weight);
    data = [(randn(Float32, 3, 5), randn(Float32, 2, 5)) for _ in 1:10];

    # Take 1: explicitly add a penalty in the loss function
    opt = Flux.setup(Adam(0.1), model)
    Flux.train!(model, data, opt) do m, x, y
        err = Flux.mse(m(x), y)
        l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
        err + 0.33 * l2
    end
    diff1 = model.weight .- init_weight

    # Take 2: the same, but with Flux.params. Was broken for a bit, no tests!
    model.weight .= init_weight
    model.bias .= 0
    pen2(x::AbstractArray) = sum(abs2, x)/2
    opt = Flux.setup(Adam(0.1), model)
    Flux.train!(model, data, opt) do m, x, y
        err = Flux.mse(m(x), y)
        l2 = sum(pen2, Flux.params(m))
        err + 0.33 * l2
    end
    diff2 = model.weight .- init_weight
    @test_broken diff1 ≈ diff2

    # Take 3: using WeightDecay instead. Need the /2 above, to match exactly.
    model.weight .= init_weight
    model.bias .= 0
    decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model);
    Flux.train!(model, data, decay_opt) do m, x, y
        Flux.mse(m(x), y)
    end
    diff3 = model.weight .- init_weight
    @test diff1 ≈ diff3
end
