Commit 57beb23

Improve docs for initialisation (#1912)
* add some methods
* init documentation
* fixup
* note for identity_init
* keyword bug
* test that Float64 keywords don't promote
* branch once in orthogonal
* can't use ;;; on 1.6
1 parent dc6f286 commit 57beb23

4 files changed, +239 -184 lines changed

docs/src/utilities.md

Lines changed: 31 additions & 10 deletions
@@ -3,37 +3,58 @@
 Flux provides utility functions which can be used to initialize your layers
 or to regularly execute callback functions.
 
-## Layer Initialization
+## Layer Initialisation
 
-These are primarily useful if you are planning to write your own layers.
-Flux initializes convolutional layers and recurrent cells with `glorot_uniform`
-by default.
-To change the default on an applicable layer, pass the desired function with the
-`init` keyword. For example:
+Flux initialises convolutional layers and recurrent cells with `glorot_uniform` by default.
+Most layers accept a function as an `init` keyword, which replaces this default. For example:
 
 ```jldoctest; setup = :(using Flux)
-julia> conv = Conv((3, 3), 1 => 8, relu; init=Flux.glorot_normal)
-Conv((3, 3), 1 => 8, relu)  # 80 parameters
+julia> conv = Conv((3, 3), 3 => 2, relu; init=Flux.glorot_normal)
+Conv((3, 3), 3 => 2, relu)  # 56 parameters
+
+julia> conv.bias
+2-element Vector{Float32}:
+ 0.0
+ 0.0
+```
+
+Note that `init` creates the weight array, but not the bias vector.
+
+Many of the initialisation functions accept keywords such as `gain`,
+and a random number generator. To make it easy to pass these to layers,
+there are methods which return a function:
+
+```jldoctest; setup = :(using Flux, Random)
+julia> Dense(4 => 5, tanh; init=Flux.glorot_uniform(gain=2))
+Dense(4 => 5, tanh)  # 25 parameters
+
+julia> Dense(4 => 5, tanh; init=Flux.randn32(MersenneTwister(1)))
+Dense(4 => 5, tanh)  # 25 parameters
 ```
 
 ```@docs
 Flux.glorot_uniform
 Flux.glorot_normal
 Flux.kaiming_uniform
 Flux.kaiming_normal
+Flux.truncated_normal
 Flux.orthogonal
 Flux.sparse_init
+Flux.identity_init
+Flux.ones32
+Flux.rand32
 ```
 
 ## Changing the type of model parameters
 
+The default `eltype` for models is `Float32` since models are often trained/run on GPUs.
+The `eltype` of model `m` can be changed to `Float64` by `f64(m)`:
+
 ```@docs
 Flux.f64
 Flux.f32
 ```
 
-The default `eltype` for models is `Float32` since models are often trained/run on GPUs. The `eltype` of model `m` can be changed to `Float64` by `f64(m)`, or to `Float32` by `f32(m)`.
-
 ## Model Building
 
 Flux provides some utility functions to help you generate models in an automated fashion.
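For context, here is a minimal sketch of how the `init` keyword and the curried initialiser methods documented in this hunk fit together. It assumes, as the new docs state, that `glorot_uniform` accepts a `gain` keyword and an RNG argument, and that the keyword-only call returns a function of the weight dimensions; `scaled_init` is a hypothetical helper, not part of Flux:

```julia
using Flux, Random

# Any function of the weight dimensions can be passed as `init`.
# `scaled_init` is a hypothetical example: glorot_uniform shrunk by half.
scaled_init(dims::Integer...) = Flux.glorot_uniform(dims...) .* 0.5f0

layer = Dense(4 => 5, relu; init = scaled_init)  # weight built by scaled_init; bias stays zero

# The curried methods return a function of the dimensions, so passing
# a keyword (or an RNG) through a layer is a one-liner:
Dense(4 => 5; init = Flux.glorot_uniform(gain = 2))     # gain forwarded to glorot_uniform
Dense(4 => 5; init = Flux.randn32(MersenneTwister(1)))  # weights drawn from this RNG

# Calling the curried form directly shows what the layer receives:
Flux.glorot_uniform(gain = 2)(5, 4)  # a 5×4 Matrix{Float32}
```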

src/functor.jl

Lines changed: 4 additions & 2 deletions
@@ -213,14 +213,16 @@ paramtype(T::Type{<:Real}, m) = fmap(x -> adapt(T, x), m)
 """
     f32(m)
 
-Convert the `eltype` of model's parameters to `Float32`.
+Converts the `eltype` of model's parameters to `Float32` (which is Flux's default).
+Recurses into structs marked with [`@functor`](@ref).
 """
 f32(m) = paramtype(Float32, m)
 
 """
     f64(m)
 
-Convert the `eltype` of model's parameters to `Float64`.
+Converts the `eltype` of model's parameters to `Float64`.
+Recurses into structs marked with [`@functor`](@ref).
 """
 f64(m) = paramtype(Float64, m)
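The two converters touched in this hunk are typically used as below; a minimal usage sketch, where the `eltype` values in the comments are what the docstrings above imply rather than doctested output:

```julia
using Flux

m = Chain(Dense(2 => 3, tanh), Dense(3 => 1))

eltype(m[1].weight)    # Float32 — Flux's default parameter type

m64 = f64(m)           # recurses through the Chain, converting parameters to Float64
eltype(m64[1].weight)  # Float64

m32 = f32(m64)         # and back to Float32
eltype(m32[1].weight)  # Float32
```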
