Commit 0d6619a

docs etc

1 parent 37521c8 commit 0d6619a

4 files changed, +70 -10 lines changed
Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
 authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
-version = "0.2.9"
+version = "0.2.10"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 ChainRulesCore = "1"
-Functors = "0.2.8, 0.3"
+Functors = "0.3"
 Zygote = "0.6.40"
 julia = "1.6"
 

docs/src/index.md

Lines changed: 59 additions & 7 deletions
@@ -1,6 +1,6 @@
 # Optimisers.jl
 
-## Defining an optimisation rule
+## An optimisation rule
 
 A new optimiser must overload two functions, [`apply!`](@ref) and [`init`](@ref).
 These act on one array of parameters:
@@ -60,18 +60,18 @@ Notice that a completely new instance of the model is returned. Internally, this
 is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the
 tree formed by the model and update the parameters using the gradients.
 
+There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
+but is free to mutate arrays within the old one for efficiency.
+The method of `apply!` for each rule is likewise free to mutate arrays within its state;
+they are defensively copied when this rule is used with `update`.
+
 Optimisers.jl does not depend on any one automatic differentiation package,
 but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
 Note that `update` always wants the gradient from Zygote's "explicit" mode, as shown above.
 This `∇model` is another tree structure, rather than the dictionary-like object from
 Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
 [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
 
-There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
-but is free to mutate arrays within the old one for efficiency.
-The method of `apply!` you write is likewise free to mutate arrays within its state;
-they are defensively copied when this rule is used with `update`.
-
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
 
 The main design difference of Lux is that the tree of parameters is separate from
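The paragraph moved up in the hunk above describes `update` versus `Optimisers.update!` only in prose, so here is a minimal sketch of the difference; the NamedTuple model, quadratic loss, and random data are invented for illustration and are not part of the commit:

```julia
using Optimisers, Zygote

model = (weight = rand(2, 3), bias = zeros(2))   # any Functors-compatible tree of arrays
state = Optimisers.setup(Momentum(0.01), model)

x, y = rand(3), rand(2)
grad = gradient(m -> sum(abs2, m.weight * x .+ m.bias .- y), model)[1]

# `update` returns a new state and a new model, leaving the old ones untouched:
state, model = Optimisers.update(state, model, grad)

# `update!` returns the same pair, but may overwrite arrays inside the old state
# and model for efficiency, so the originals should not be reused afterwards:
state, model = Optimisers.update!(state, model, grad)
```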
@@ -110,6 +110,57 @@ Besides the parameters stored in `params` and gradually optimised, any other mod
 is stored in `lux_state`. For simplicity this example does not show how to propagate the
 updated `lux_state` to the next iteration, see Lux's documentation.
 
+## Non-`trainable` Parameters
+
+Optimisers.jl uses [Functors.jl](https://fluxml.ai/Functors.jl) to walk the `struct`s
+making up the model, for which they must be annotated `@functor Type`.
+By default optimisation will alter all [`isnumeric`](@ref) arrays.
+
+If some arrays of a particular layer should not be treated this way,
+you can define a method for [`trainable`](@ref):
+
+```julia
+struct Layer{T}
+  alpha::T
+  beta::T
+  length::Int
+end
+Layer(n::Int) = Layer(randn(n), zeros(n), n)
+
+Functors.@functor Layer
+
+# Both array fields will be, for example, moved to the GPU:
+Functors.children(Layer(3))  # (alpha = [...], beta = [...], length)
+
+Optimisers.trainable(x::Layer) = (; alpha = x.alpha)  # must be a subset of children
+
+# Only the first field will be optimised:
+st = Optimisers.setup(DecayDescent(0.1), Layer(3))
+```
+
+## Tied Parameters
+
+If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.
+Within Optimisers.jl, `setup` will initialise once, and use the same `Leaf` for both parameters.
+Then `update` will accumulate the gradient from both, and the updated model returned will have the tie maintained.
+
+```julia
+using Flux, Optimisers
+
+enc = Chain(Dense(40 => 20, tanh), Dense(20 => 10));
+dec = Chain(Dense(enc[1].weight', true, tanh), Dense(enc[2].weight', true, tanh));
+model = Chain(; enc, dec)
+
+st = Optimisers.setup(Optimisers.Adam(), model);
+
+st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent  # true
+```
+
+This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s.
+It will not at present work for `reshape`d arrays, nor for immutable arrays such as those
+from StaticArrays.jl.
+
+
 ## Obtaining a flat parameter vector
 
 Instead of a nested tree-like structure, sometimes it is convenient to have all the
@@ -143,10 +194,11 @@ st, flat = Optimisers.update(st, flat, ∇flat)
 ```
 
 Here `flat` contains only the 283 trainable parameters, while the non-trainable
-ones are preserved inside `re`.
+ones are preserved inside `re`, an object of type `Restructure`.
 When defining new layers, these can be specified if necessary by overloading [`trainable`](@ref).
 By default, all numeric arrays visible to [Functors.jl](https://github.com/FluxML/Functors.jl)
 are assumed to contain trainable parameters.
+Tied parameters (arrays appearing in different layers) are included only once in `flat`.
 
 Lux stores only the trainable parameters in `params`.
 This can also be flattened to a plain `Vector` in the same way:
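The last hunk above adds two notes about the flattened form: `re` is a `Restructure` object, and tied arrays appear only once in `flat`. A minimal sketch of that interface, assuming a small made-up Flux model rather than the 283-parameter one in the docs:

```julia
using Flux, Optimisers

model = Chain(Dense(3 => 2, tanh), Dense(2 => 1))
flat, re = Optimisers.destructure(model)

re isa Optimisers.Restructure   # true; `re` remembers the structure and non-trainable fields
length(flat)                    # number of trainable parameters, counting any tied array once
model2 = re(flat)               # rebuild a model of the same shape from a flat vector
```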

src/Optimisers.jl

Lines changed: 8 additions & 0 deletions
@@ -16,6 +16,10 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
        WeightDecay, ClipGrad, ClipNorm, OptimiserChain
 
+###
+### one-array functions
+###
+
 """
     Optimisers.apply!(rule::RuleType, state, parameters, gradient) -> (state, gradient)
 
@@ -57,6 +61,10 @@ julia> Optimisers.init(Momentum(), [1.0, 2.0])
 """
 init
 
+###
+### whole-model functions
+###
+
 """
     Optimisers.setup(rule, model) -> tree
 
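The docstrings grouped under the new `### one-array functions` banner can be exercised directly on a single array. A minimal sketch, with arbitrary parameter and gradient values:

```julia
using Optimisers

x = [1.0, 2.0]                     # one array of parameters
dx = [0.1, 0.2]                    # its gradient

rule = Momentum()                  # exported rule, as in the `init` docstring example
state = Optimisers.init(rule, x)   # fresh per-array state (a velocity buffer for Momentum)

# `apply!` adjusts the gradient for this one array and returns the new state
# together with the step that should be subtracted from the parameters:
state, dx′ = Optimisers.apply!(rule, state, x, dx)
x .-= dx′
```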

src/interface.jl

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ end
 ###
 
 """
-    @.. x = x + y
+    @.. x = y + z
 
 Sometimes in-place broadcasting macro, for use in `apply!` rules.
 If `maywrite(x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`.
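For context on the corrected docstring, a sketch of how `@..` might be used inside a rule's `apply!`. The momentum-like rule below is hypothetical (not the package's own `Momentum`), and `@..` is an internal, unexported macro, so it is qualified here:

```julia
using Optimisers

# Hypothetical rule, defined only to illustrate the macro.
struct MyMomentum
  eta::Float64
  rho::Float64
end

Optimisers.init(o::MyMomentum, x::AbstractArray) = zero(x)

function Optimisers.apply!(o::MyMomentum, vel, x, dx)
  # Writes into `vel` when `maywrite(vel)` is true; for immutable arrays it broadcasts
  # into a new array instead, so the result must be captured rather than relying on mutation:
  newvel = Optimisers.@.. vel = o.rho * vel + o.eta * dx
  return newvel, newvel
end
```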
