- # Freezing model weights
+ # Choosing differentiable/GPU parts of the model
+ !!! note
+     This tutorial covers a few loosely connected topics on customizing your
+     models further. It assumes familiarity with [`Flux.@layer`](@ref),
+     [`Flux.@functor`](@ref), [`freeze!`](@ref Flux.freeze!) and other
+     basics of Flux.
+
Flux provides several ways of freezing parameters, excluding them from backprop
entirely, and marking custom struct fields so that they are not moved to the GPU
([`Functors.@functor`](@ref)) and hence not trained. The following
```
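As a rough sketch of freezing part of a model at training time with
[`freeze!`](@ref Flux.freeze!) (the model, optimiser and layer index below are
illustrative, not taken from this page):

```julia
using Flux

m = Chain(Dense(3 => 4, relu), Dense(4 => 1))
opt_state = Flux.setup(Adam(), m)

# Freeze the optimiser state of the first layer: its parameters still take part
# in the forward and backward pass, but `train!` will no longer update them.
Flux.freeze!(opt_state.layers[1])

# ... train as usual ...

Flux.thaw!(opt_state)   # undo the freeze for the whole model
```
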
## Static freezing per model definition
- Sometimes some parts of the model ([`Flux.@functor`](@ref)) needn't to be trained at all but these params
+ Sometimes some parts of the model ([`Flux.@layer`](@ref)) need not be trained at all, but these params
still need to reside on the GPU (these params are still needed in the forward
and/or backward pass).
```julia
struct MaskedLayer{T}
  chain::Chain
  mask::T
end
- Flux.@functor MaskedLayer
-
- # mark the trainable part
- Flux.trainable(a::MaskedLayer) = (; a.chain)
- # a.mask will not be updated in the training loop
+ Flux.@layer MaskedLayer trainable=(chain,)
+ # the mask field will not be updated in the training loop
function (m::MaskedLayer)(x)
+   # the mask field will still move to the GPU for efficient operations:
  return m.chain(x) + x + m.mask
end
```
@@ -61,7 +65,7 @@ Note how this method permanently sets some model fields to be excluded from
training, without any on-the-fly changes.
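A brief usage sketch under the definitions above (the sizes and the calls
below are illustrative only):

```julia
m = MaskedLayer(Chain(Dense(3 => 3)), ones(Float32, 3))

Flux.trainable(m)   # returns (chain = ...,): only the chain is seen by training
opt_state = Flux.setup(Adam(), m)   # hence no optimiser state is created for m.mask

# gpu(m) would still move both m.chain and m.mask to the GPU
```
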
## Excluding from model definition
- Sometimes some parameters are just "not trainable" but they shouldn't even
+ Sometimes some parameters are not merely "not trainable" but should not even
transfer to the GPU. All scalar fields are like this by default, so things like
learning rate multipliers are neither trainable nor transferred to the GPU by
default.
@@ -82,7 +86,7 @@ function (m::CustomLayer)(x)
  return result
end
```
- See more about this in [`Flux.@functor`](@ref) and
+ See more about this in [`Flux.@functor`](@ref)
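A minimal sketch of this default behaviour; `ScaledDense` and its fields are
hypothetical names used only for illustration, not part of Flux or of the
example above:

```julia
using Flux

struct ScaledDense
  dense::Dense
  lr_multiplier::Float64   # plain scalar field
end
Flux.@layer ScaledDense

(m::ScaledDense)(x) = m.lr_multiplier .* m.dense(x)

m = ScaledDense(Dense(2 => 2), 0.1)
# `gpu` converts the arrays inside m.dense but leaves the Float64 scalar as it
# is, and `Flux.setup` creates no optimiser state for it, so it is never trained.
gpu_m = m |> gpu
```
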
## Freezing Layer Parameters (deprecated)