
Commit 5514952

restructre
1 parent 7f234d6 commit 5514952

2 files changed: +14, -10 lines


docs/make.jl

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ makedocs(
         =#
         # Not really sure where this belongs... some in Fluxperimental, aim to delete?
         "Custom Layers" => "models/advanced.md",
-        "Freezing model params" => "models/freezing-params.md",
+        "Advanced tweaking of models" => "tutorials/misc-model-tweaking.md",
     ],
 ],
 format = Documenter.HTML(
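For context, an illustrative Documenter.jl skeleton showing where the renamed page now sits in the docs build; the section name "Tutorials" and the `sitename` are assumptions, only the two quoted page entries come from the hunk above.

```julia
using Documenter, Flux

makedocs(
    sitename = "Flux",          # assumed; not part of this diff
    pages = [
        "Tutorials" => [        # assumed grouping around the two entries below
            "Custom Layers" => "models/advanced.md",
            "Advanced tweaking of models" => "tutorials/misc-model-tweaking.md",
        ],
    ],
    format = Documenter.HTML(),
)
```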

docs/src/models/freezing-params.md renamed to docs/src/tutorials/misc-model-tweaking.md

Lines changed: 13 additions & 9 deletions
@@ -1,4 +1,10 @@
-# Freezing model weights
+# Choosing differentiable/gpu parts of the model
+!!! note
+    This tutorial features somewhat disconnected topics about customizing your
+    models even further. It is advised to be familiar with
+    [`Flux.@layer`](@ref), [`Flux.@functor`](@ref), [`freeze!`](@ref
+    Flux.freeze!) and other basics of Flux.
+
 Flux provides several ways of freezing, excluding from backprop entirely and
 marking custom struct fields not to be moved to the GPU
 ([Functors.@functor](@ref)) hence excluded from being trained. The following
@@ -37,21 +43,19 @@ end
 ```
 
 ## Static freezing per model definition
-Sometimes some parts of the model ([`Flux.@functor`](@ref)) needn't to be trained at all but these params
+Sometimes some parts of the model ([`Flux.@layer`](@ref)) needn't to be trained at all but these params
 still need to reside on the GPU (these params are still needed in the forward
 and/or backward pass).
 ```julia
 struct MaskedLayer{T}
   chain::Chain
   mask::T
 end
-Flux.@functor MaskedLayer
-
-# mark the trainable part
-Flux.trainable(a::MaskedLayer)=(;a.chain)
-# a.mask will not be updated in the training loop
+Flux.@layer MyLayer trainable=(chain,)
+# mask field will not be updated in the training loop
 
 function (m::MaskedLayer)(x)
+  # mask field will still move to to gpu for efficient operations:
   return m.chain(x) + x + m.mask
 end
 ```
@@ -61,7 +65,7 @@ Note how this method permanently sets some model fields to be excluded from
 training without on-the-fly changing.
 
 ## Excluding from model definition
-Sometimes some parameters are just "not trainable" but they shouldn't even
+Sometimes some parameters aren't just "not trainable" but they shouldn't even
 transfer to the GPU. All scalar fields are like this by default, so things like
 learning rate multipliers are not trainable nor transferred to the GPU by
 default.
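The `CustomLayer` used in this section is only partly visible in the hunks below; here is a self-contained sketch of the same pattern, with hypothetical field names, restricting `@functor` so that only `chain` is moved by `gpu` and seen by training.

```julia
using Flux, Statistics

struct CustomLayer{F}
  chain::Chain
  history::Vector{F}      # bookkeeping buffer: stays on the CPU, never trained
  lr_multiplier::Float64  # plain scalar: not trainable nor moved by `gpu` anyway
end
# Declaring only `chain` as a child makes it the sole field that `gpu`/`cpu`
# convert and the only one visited when collecting parameters.
Flux.@functor CustomLayer (chain,)

function (m::CustomLayer)(x)
  y = m.chain(x) .* m.lr_multiplier
  push!(m.history, mean(y))   # plain CPU bookkeeping in the forward pass
  return y
end

m = CustomLayer(Chain(Dense(4 => 4, relu)), Float32[], 0.1)
m(randn(Float32, 4, 2))
```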
@@ -82,7 +86,7 @@ function (m::CustomLayer)(x)
   return result
 end
 ```
-See more about this in [`Flux.@functor`](@ref) and
+See more about this in [`Flux.@functor`](@ref)
 
 
 ## Freezing Layer Parameters (deprecated)
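The note added at the top of the tutorial points readers to `Flux.freeze!` as the on-the-fly counterpart to the deprecated section above. A minimal sketch of that workflow, with a made-up two-layer model; it acts on the optimiser state, not on the model itself.

```julia
using Flux

model = Chain(Dense(3 => 8, relu), Dense(8 => 3))
opt_state = Flux.setup(Adam(), model)

# Skip the first layer in every update! call, without touching the model.
Flux.freeze!(opt_state.layers[1])

x, y = randn(Float32, 3, 16), randn(Float32, 3, 16)
grads = Flux.gradient(m -> Flux.mse(m(x), y), model)[1]
Flux.update!(opt_state, model, grads)   # only the second layer's parameters change

Flux.thaw!(opt_state)                   # undo the freeze for later training
```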
