Commit e439ae7

mcabbott and darsnack authored
Update description of trainable in "advanced.md" (#2289)
* Update advanced.md

* Update docs/src/models/advanced.md

Co-authored-by: Kyle Daruwalla <daruwalla@wisc.edu>
1 parent 8c23af3 commit e439ae7


docs/src/models/advanced.md

Lines changed: 18 additions & 14 deletions
@@ -36,34 +36,38 @@ For an intro to Flux and automatic differentiation, see this [tutorial](https://

 Taking reference from our example `Affine` layer from the [basics](@ref man-basics).
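For reference, the `Affine` layer from the basics page looks roughly like this (a sketch; the exact definition lives in the basics section of the docs):

```julia
# Sketch of the `Affine` layer assumed by the examples below:
struct Affine
  W
  b
end

Affine(in::Integer, out::Integer) = Affine(randn(out, in), randn(out))

# Overload call, so the object can be used as a function
(m::Affine)(x) = m.W * x .+ m.b
```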

-By default all the fields in the `Affine` type are collected as its parameters, however, in some cases it may be desired to hold other metadata in our "layers" that may not be needed for training, and are hence supposed to be ignored while the parameters are collected. With Flux, it is possible to mark the fields of our layers that are trainable in two ways.
-
-The first way of achieving this is through overloading the `trainable` function.
+By default, all the fields in the `Affine` type are collected as its parameters. However, in some cases it may be desirable to hold other metadata in our "layers" that is not needed for training, and should therefore be ignored while the parameters are collected. With Flux, the way to mark some fields of our layer as trainable is through overloading the `trainable` function:

 ```julia-repl
-julia> @functor Affine
+julia> Flux.@functor Affine

-julia> a = Affine(rand(3,3), rand(3))
-Affine{Array{Float64,2},Array{Float64,1}}([0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955])
+julia> a = Affine(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9])
+Affine(Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0])

 julia> Flux.params(a) # default behavior
-Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297], [0.42394, 0.0170927, 0.544955]])
+Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]])

-julia> Flux.trainable(a::Affine) = (a.W,)
+julia> Flux.trainable(a::Affine) = (; a.W) # returns a NamedTuple using the field's name

 julia> Flux.params(a)
-Params([[0.66722 0.774872 0.249809; 0.843321 0.403843 0.429232; 0.683525 0.662455 0.065297]])
+Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0]])
 ```
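A minimal sketch of how the overloaded `trainable` interacts with explicit-style training, assuming the `Affine` definition above and a recent Flux version; `Flux.setup` should create optimiser state only for `W`, leaving `b` as an empty leaf `()`:

```julia
using Flux

Flux.@functor Affine                   # as in the diff above
Flux.trainable(a::Affine) = (; a.W)    # only W is trainable

a = Affine(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9])

# Optimiser state is created only for the trainable field `W`;
# the ignored field `b` should appear as an empty leaf `()`.
opt_state = Flux.setup(Adam(0.01), a)

# An explicit-style update then modifies only `W`:
grads = Flux.gradient(m -> sum(abs2, m(ones(Float32, 2))), a)
Flux.update!(opt_state, a, grads[1])
```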

-Only the fields returned by `trainable` will be collected as trainable parameters of the layer when calling `Flux.params`.
+Only the fields returned by `trainable` will be collected as trainable parameters of the layer when calling `Flux.params`, and only these fields will be seen by `Flux.setup` and `Flux.update!` for training. But all fields will be seen by `gpu` and similar functions, for example:
+
+```julia-repl
+julia> a |> f16
+Affine(Float16[1.0 2.0; 3.0 4.0; 5.0 6.0], Float16[7.0, 8.0, 9.0])
+```

-Another way of achieving this is through the `@functor` macro directly. Here, we can mark the fields we are interested in by grouping them in the second argument:
+Note that there is no need to overload `trainable` to hide fields which do not contain trainable parameters (for example, activation functions or Boolean flags). These are always ignored by `params` and by training:

-```julia
-Flux.@functor Affine (W,)
+```julia-repl
+julia> Flux.params(Affine(true, [10, 11, 12.0]))
+Params([])
 ```
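For instance, the same point with a function-valued field (a sketch; `MyAffine` and its `σ` field are illustrative, by analogy with Flux's built-in `Dense`):

```julia
using Flux

struct MyAffine
  W
  b
  σ        # activation function, e.g. relu; not a trainable array
end

Flux.@functor MyAffine

(m::MyAffine)(x) = m.σ.(m.W * x .+ m.b)

# The function field σ is ignored by params automatically,
# with no need to overload `trainable`:
Flux.params(MyAffine(rand(3, 2), rand(3), relu))   # collects only W and b
```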

-However, doing this requires the `struct` to have a corresponding constructor that accepts those parameters.
+It is also possible to further restrict which fields are seen, by writing `@functor Affine (W,)`. However, this is not recommended. It requires the `struct` to have a corresponding constructor that accepts only `W` as an argument, and the ignored fields will not be seen by functions like `gpu` (which is usually undesired).
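A sketch of what that constructor requirement looks like in practice (the name `AffineW` and the default bias are illustrative assumptions, chosen to avoid clashing with `Affine` above):

```julia
using Flux

struct AffineW
  W
  b
end

# Restrict functor traversal to W only (not recommended, per the text above):
Flux.@functor AffineW (W,)

# Reconstruction after fmap/gpu calls AffineW(new_W), so a one-argument
# constructor must exist, e.g. one filling in a default bias:
AffineW(W) = AffineW(W, zeros(eltype(W), size(W, 1)))

(m::AffineW)(x) = m.W * x .+ m.b
```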

 ## Freezing Layer Parameters
