Commit 383fb84

docs: improve freezing docs
1 parent 48a43db commit 383fb84

3 files changed: +129 -41 lines


docs/make.jl

Lines changed: 2 additions & 1 deletion
@@ -54,7 +54,8 @@ makedocs(
         "Deep Convolutional GAN" => "tutorials/2021-10-08-dcgan-mnist.md",
         =#
         # Not really sure where this belongs... some in Fluxperimental, aim to delete?
-        "Custom Layers" => "models/advanced.md", # TODO move freezing to Training
+        "Custom Layers" => "models/advanced.md",
+        "Freezing model params" => "models/freezing-params.md",
     ],
 ],
 format = Documenter.HTML(

docs/src/models/advanced.md

Lines changed: 0 additions & 40 deletions
@@ -69,46 +69,6 @@ Params([])
 
 It is also possible to further restrict what fields are seen by writing `@functor Affine (W,)`. However, this is not recommended. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument, and the ignored fields will not be seen by functions like `gpu` (which is usually undesired).
 
-## Freezing Layer Parameters
-
-When it is desired to not include all the model parameters (for e.g. transfer learning), we can simply not pass in those layers into our call to `params`.
-
-!!! compat "Flux ≤ 0.14"
-    The mechanism described here is for Flux's old "implicit" training style.
-    When upgrading for Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`.
-
-Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain
-this using the slicing features `Chain` provides:
-
-```julia
-m = Chain(
-      Dense(784 => 64, relu),
-      Dense(64 => 64, relu),
-      Dense(32 => 10)
-    );
-
-ps = Flux.params(m[3:end])
-```
-
-The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it.
-
-During training, the gradients will only be computed for (and applied to) the last `Dense` layer, therefore only that would have its parameters changed.
-
-`Flux.params` also takes multiple inputs to make it easy to collect parameters from heterogenous models with a single call. A simple demonstration would be if we wanted to omit optimising the second `Dense` layer in the previous example. It would look something like this:
-
-```julia
-Flux.params(m[1], m[3:end])
-```
-
-Sometimes, a more fine-tuned control is needed.
-We can freeze a specific parameter of a specific layer which already entered a `Params` object `ps`,
-by simply deleting it from `ps`:
-
-```julia
-ps = Flux.params(m)
-delete!(ps, m[2].bias)
-```
-
 ## Custom multiple input or output layer
 
 Sometimes a model needs to receive several separate inputs at once or produce several separate outputs at once. In other words, there are multiple paths within this high-level layer, each processing a different input or producing a different output. A simple example of this in machine learning literature is the [inception module](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Szegedy_Rethinking_the_Inception_CVPR_2016_paper.pdf).

docs/src/models/freezing-params.md

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
# Freezing model weights

Flux provides several ways of freezing parameters: freezing them on the fly per model instance, marking parts of a model as non-trainable in its definition, and marking custom struct fields so that they are not moved to the GPU ([Functors.@functor](@ref)) and hence not trained at all. The following subsections should make it clear which one suits your needs best.

## On-the-fly freezing per model instance

Perhaps you'd like to freeze some of the weights of the model (even mid-training), and Flux accomplishes this through [`freeze!`](@ref Flux.freeze!) and `thaw!`.

```julia
m = Chain(
      Dense(784 => 64, relu),   # freeze this one
      Dense(64 => 64, relu),
      Dense(64 => 10)
    )
opt_state = Flux.setup(Momentum(), m);

# Freeze some layers right away
Flux.freeze!(opt_state.layers[1])

# `train_set` and `loss` are placeholders assumed to be defined elsewhere
for data in train_set
  input, label = data

  # Some params could also be frozen during training:
  Flux.freeze!(opt_state.layers[2])

  grads = Flux.gradient(m) do m
      result = m(input)
      loss(result, label)
  end
  Flux.update!(opt_state, m, grads[1])

  # Optionally unfreeze the params later
  Flux.thaw!(opt_state.layers[1])
end
```
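
As a brief aside (a sketch not taken from the original docs): `freeze!` and `thaw!` act on the optimiser state returned by `Flux.setup`, not on the model itself, and they can be applied to any sub-tree of that state.

```julia
# Continuing the example above: freeze/thaw any part of the optimiser state.
Flux.freeze!(opt_state.layers[3])   # stop updating the last Dense layer
Flux.thaw!(opt_state)               # lift every freeze in the whole state tree
```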

## Static freezing per model definition

Sometimes some parts of the model ([`Flux.@functor`](@ref)) need not be trained at all, but those parameters still need to reside on the GPU because they are used in the forward and/or backward pass.

```julia
struct MaskedLayer{T}
  chain::Chain
  mask::T
end
Flux.@functor MaskedLayer

# mark the trainable part
Flux.trainable(a::MaskedLayer) = (; a.chain)
# a.mask will not be updated in the training loop

function (m::MaskedLayer)(x)
  return m.chain(x) + x + m.mask
end

model = MaskedLayer(...) # this model will not have the `mask` field trained
```
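
To see the effect, here is a minimal usage sketch (the sizes and the `Adam` rule are illustrative, not from the original docs): because `mask` is not listed in `trainable`, the optimiser state set up by `Flux.setup` covers only `chain`, and `Flux.update!` leaves the mask untouched.

```julia
masked = MaskedLayer(Chain(Dense(4 => 4)), zeros(Float32, 4))
masked_state = Flux.setup(Adam(), masked)
# `masked_state` holds optimiser state for the weights inside `masked.chain` only;
# `masked.mask` gets none and is therefore never changed by Flux.update!
```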

Note that this approach permanently excludes the chosen fields from training; unlike `freeze!`/`thaw!`, it cannot be changed on the fly.

## Excluding from model definition

Sometimes some parameters are simply "not trainable", and they should not even be transferred to the GPU. All scalar fields are like this by default, so things like learning-rate multipliers are neither trained nor moved to the GPU.

```julia
using Statistics: mean

struct CustomLayer{T, F}
  chain::T
  activation_results::Vector{F}
  lr_multiplier::Float32
end
Flux.@functor CustomLayer (chain,) # explicitly leave out `activation_results` and `lr_multiplier`

function (m::CustomLayer)(x)
  result = m.chain(x) + x

  # `activation_results` is not part of the GPU loop, hence we can do
  # things like `push!` to it here
  push!(m.activation_results, mean(result))
  return result
end
```
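
Because only `chain` is listed as a functor child, device movement obeys the same restriction. A short usage sketch (the concrete sizes are made up for illustration):

```julia
layer = CustomLayer(Chain(Dense(4 => 4)), Float32[], 0.1f0)
layer_gpu = layer |> gpu
# When a GPU is available, only `layer_gpu.chain` is moved to it;
# `activation_results` and `lr_multiplier` stay as ordinary CPU values.
```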

See more about this in [`Flux.@functor`](@ref).

## Freezing Layer Parameters (deprecated)

When it is desired not to include all of the model's parameters (for example, in transfer learning), we can simply leave those layers out of our call to `params`.

!!! compat "Flux ≤ 0.14"
    The mechanism described here is for Flux's old "implicit" training style.
    When upgrading for Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`.

Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain
this using the slicing features `Chain` provides:

```julia
m = Chain(
      Dense(784 => 64, relu),
      Dense(64 => 64, relu),
      Dense(64 => 10)
    );

ps = Flux.params(m[3:end])
```

The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it.

During training, the gradients will only be computed for (and applied to) the last `Dense` layer, therefore only that one will have its parameters changed.

`Flux.params` also takes multiple inputs to make it easy to collect parameters from heterogeneous models with a single call. A simple demonstration would be if we wanted to omit optimising only the second `Dense` layer in the previous example. It would look something like this:

```julia
Flux.params(m[1], m[3:end])
```

Sometimes, more fine-grained control is needed.
We can freeze a specific parameter of a specific layer which has already entered a `Params` object `ps`
by simply deleting it from `ps`:

```julia
ps = Flux.params(m)
delete!(ps, m[2].bias)
```
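
For completeness, a hedged sketch of how such a `Params` collection was consumed by the old implicit-style `Flux.train!` (here `train_set` is a placeholder data iterator and `Descent` is just an example optimiser):

```julia
opt = Descent(0.1)
# Old implicit API (Flux ≤ 0.14): gradients are taken only w.r.t. the entries of `ps`,
# so layers left out of `ps` are never updated.
Flux.train!((x, y) -> Flux.logitcrossentropy(m(x), y), ps, train_set, opt)
```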
