# Optimisers.jl

## An optimisation rule

A new optimiser must overload two functions, [`apply!`](@ref) and [`init`](@ref).
These act on one array of parameters:
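
The package's own example block is not shown in this excerpt. As a hedged sketch, a rule such as the `DecayDescent(0.1)` used later on this page might be written as follows. The `apply!` contract is `(rule, state, x, dx) -> (state′, dx′)`, with `dx′` being what `update` subtracts from `x`; the exact `eta / √t` schedule shown here is an assumption, not necessarily the package's own definition:

```julia
using Optimisers

# A gradient-descent rule whose step size decays as eta / √t (illustrative schedule).
struct DecayDescent <: Optimisers.AbstractRule
  eta::Float64
end

# The per-array state is just an iteration counter, starting at 1.
Optimisers.init(o::DecayDescent, x::AbstractArray) = 1

function Optimisers.apply!(o::DecayDescent, t, x, dx)
  η = convert(eltype(x), o.eta / sqrt(t))   # step size for this iteration, in x's eltype
  dx′ = @. dx * η                           # scaled gradient, which `update` subtracts from x
  return t + 1, dx′
end
```

With such a rule defined, `Optimisers.setup` and `Optimisers.update` treat it exactly like the built-in rules.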
Notice that a completely new instance of the model is returned. Internally, this
is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the
tree formed by the model and update the parameters using the gradients.


There is also [`Optimisers.update!`](@ref), which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
The method of `apply!` for each rule is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.
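
As a minimal sketch of the difference (assuming `state`, `model`, and `∇model` from a setup and gradient call like the one described above, which is not shown here):

```julia
# Out-of-place: the old model and state are left untouched.
state, model = Optimisers.update(state, model, ∇model)

# In-place variant: may overwrite arrays inside the old model and state,
# so keep using only the returned values afterwards.
state, model = Optimisers.update!(state, model, ∇model)
```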

Optimisers.jl does not depend on any one automatic differentiation package,
but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
Note that `update` always wants the gradient from Zygote's "explicit" mode, as shown above.
This `∇model` is another tree structure, rather than the dictionary-like object from
Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
[Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
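
For concreteness, a hedged toy example of the explicit-mode workflow (the NamedTuple "model" and the loss below are illustrative stand-ins, not the package's own example):

```julia
using Zygote, Optimisers

# A toy "model": a NamedTuple of arrays, which Zygote and Optimisers can both traverse.
model = (weight = randn(2, 3), bias = zeros(2))
data = randn(3)

loss(m, x) = sum(abs2, m.weight * x .+ m.bias)

# Explicit mode: differentiate with respect to the model object itself.
∇model = Zygote.gradient(m -> loss(m, data), model)[1]   # a NamedTuple mirroring `model`

state = Optimisers.setup(Optimisers.Adam(), model)
state, model = Optimisers.update(state, model, ∇model)
```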

## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

The main design difference of Lux is that the tree of parameters is separate from the layer structure.
Besides the parameters stored in `params` and gradually optimised, any other model state
is stored in `lux_state`. For simplicity this example does not show how to propagate the
updated `lux_state` to the next iteration; see Lux's documentation.

## Non-`trainable` Parameters

Optimisers.jl uses [Functors.jl](https://fluxml.ai/Functors.jl) to walk the `struct`s
making up the model, for which they must be annotated `@functor Type`.
By default, optimisation will alter all [`isnumeric`](@ref) arrays.

If some arrays of a particular layer should not be treated this way,
you can define a method for [`trainable`](@ref):

```julia
using Functors, Optimisers

struct Layer{T}
  alpha::T
  beta::T
  length::Int
end
Layer(n::Int) = Layer(randn(n), zeros(n), n)

Functors.@functor Layer

# Both array fields will be, for example, moved to the GPU:
Functors.children(Layer(3))  # (alpha = [...], beta = [...], length = 3)

Optimisers.trainable(x::Layer) = (; alpha = x.alpha)  # must be a subset of children

# Only the first field will be optimised:
st = Optimisers.setup(DecayDescent(0.1), Layer(3))
```
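
A hedged sketch of the effect on one update step (the hand-written gradient NamedTuple is an assumption about the shape an AD tool would return):

```julia
layer = Layer(3)
st = Optimisers.setup(DecayDescent(0.1), layer)

# A gradient with the same field names as the layer; `nothing` marks "no gradient here".
grad = (alpha = ones(3), beta = nothing, length = nothing)

st, layer = Optimisers.update(st, layer, grad)
# layer.alpha has moved against the gradient; layer.beta and layer.length are untouched.
```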

## Tied Parameters

If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.
Within Optimisers.jl, `setup` will initialise once, and use the same `Leaf` for both parameters.
Then `update` will accumulate the gradient from both, and the updated model returned will have the tie maintained.

```julia
using Flux, Optimisers

enc = Chain(Dense(40 => 20, tanh), Dense(20 => 10));
dec = Chain(Dense(enc[1].weight', true, tanh), Dense(enc[2].weight', true, tanh));
model = Chain(; enc, dec)

st = Optimisers.setup(Optimisers.Adam(), model);

st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent  # true
```

This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s.
It will not at present work for `reshape`d arrays, nor for immutable arrays such as those
from StaticArrays.jl.
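
Since ties are detected by object identity rather than by value, here is a hedged toy illustration (plain NamedTuples, not Flux layers):

```julia
using Optimisers

w = randn(10, 10)
tied   = (first = w,       second = w)        # the same object twice => one shared Leaf
untied = (first = copy(w), second = copy(w))  # equal values, but different objects

st_tied   = Optimisers.setup(Optimisers.Adam(), tied)
st_untied = Optimisers.setup(Optimisers.Adam(), untied)

st_tied.first === st_tied.second      # true
st_untied.first === st_untied.second  # false
```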

## Obtaining a flat parameter vector

Instead of a nested tree-like structure, sometimes it is convenient to have all the
trainable parameters in one flat vector.
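
The example this excerpt continues is not shown; a hedged sketch of the kind of code that leads up to the line below, using `Optimisers.destructure` (the `model` and `loss` names are placeholders):

```julia
using Optimisers, Zygote

flat, re = Optimisers.destructure(model)   # flat::Vector of parameters; re rebuilds the structure

∇flat = Zygote.gradient(v -> loss(re(v)), flat)[1]

st = Optimisers.setup(Optimisers.Adam(), flat)
```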
```julia
st, flat = Optimisers.update(st, flat, ∇flat)
```

Here `flat` contains only the 283 trainable parameters, while the non-trainable
ones are preserved inside `re`, an object of type `Restructure`.
When defining new layers, these can be specified if necessary by overloading [`trainable`](@ref).
By default, all numeric arrays visible to [Functors.jl](https://github.com/FluxML/Functors.jl)
are assumed to contain trainable parameters.
Tied parameters (arrays appearing in different layers) are included only once in `flat`.
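
For instance, continuing with the same names (a hedged sketch):

```julia
model′ = re(flat)   # rebuilds an object like the original model, with parameters taken from flat
length(flat)        # 283, counting each tied array only once
```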

Lux stores only the trainable parameters in `params`.
This can also be flattened to a plain `Vector` in the same way:
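
A hedged sketch, reusing the `params` from the Lux example earlier on this page (not shown in this excerpt):

```julia
flat, re = Optimisers.destructure(params)   # flat::Vector of Lux's trainable parameters
```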