Commit 9c12e5d

mcabbott and ToucheSir authored
Allow shared parameters, take III (#106)
* allow shared parameters, take III
* one more dict to allow artificial ties
* a tidier idea, just replace _default_walk
* add a LeafCache type, to make fmap ignore () singleton
* remove leaf.frozen field
* eager accumulation
* give up on customising fmap & write the recursion, add evil tests
* add ismutable check
* docs etc
* fix doctests
* group the tests

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
1 parent bf54f76 commit 9c12e5d

File tree: 6 files changed, +352 −93 lines changed
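
The headline change is support for tied (shared) parameters. As a rough sketch of the behaviour described in the docs below (not part of the diff; the plain-NamedTuple model and fake gradient are invented for illustration):

```julia
using Optimisers

x = rand(Float32, 2)
m = (a = x, b = x, c = rand(Float32, 2))   # the same array appears under two names

st = Optimisers.setup(Momentum(), m)
st.a === st.b                               # true: one shared Leaf records the tie

g = (a = [1f0, 1f0], b = [1f0, 1f0], c = [1f0, 1f0])  # fake gradient
st, m = Optimisers.update(st, m, g)         # gradients for a and b are accumulated
m.a === m.b                                 # true: the tie survives the update
```

Before this commit, the same `setup` call threw an `ArgumentError` about tied weights, as the removed code in `src/interface.jl` below shows.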

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
 authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
-version = "0.2.9"
+version = "0.2.10"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 ChainRulesCore = "1"
-Functors = "0.2.8, 0.3"
+Functors = "0.3"
 Zygote = "0.6.40"
 julia = "1.6"

docs/src/index.md

Lines changed: 59 additions & 7 deletions
@@ -1,6 +1,6 @@
 # Optimisers.jl
 
-## Defining an optimisation rule
+## An optimisation rule
 
 A new optimiser must overload two functions, [`apply!`](@ref) and [`init`](@ref).
 These act on one array of parameters:
@@ -60,18 +60,18 @@ Notice that a completely new instance of the model is returned. Internally, this
 is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the
 tree formed by the model and update the parameters using the gradients.
 
+There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
+but is free to mutate arrays within the old one for efficiency.
+The method of `apply!` for each rule is likewise free to mutate arrays within its state;
+they are defensively copied when this rule is used with `update`.
+
 Optimisers.jl does not depend on any one automatic differentiation package,
 but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
 Note that `update` always wants the gradient from Zygote's "explicit" mode, as shown above.
 This `∇model` is another tree structure, rather than the dictionary-like object from
 Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
 [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
 
-There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
-but is free to mutate arrays within the old one for efficiency.
-The method of `apply!` you write is likewise free to mutate arrays within its state;
-they are defensively copied when this rule is used with `update`.
-
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
 
 The main design difference of Lux is that the tree of parameters is separate from
@@ -110,6 +110,57 @@ Besides the parameters stored in `params` and gradually optimised, any other mod
 is stored in `lux_state`. For simplicity this example does not show how to propagate the
 updated `lux_state` to the next iteration, see Lux's documentation.
 
+## Non-`trainable` Parameters
+
+Optimisers.jl uses [Functors.jl](https://fluxml.ai/Functors.jl) to walk the `struct`s
+making up the model, for which they must be annotated `@functor Type`.
+By default optimisation will alter all [`isnumeric`](@ref) arrays.
+
+If some arrays of a particular layer should not be treated this way,
+you can define a method for [`trainable`](@ref)
+
+```julia
+struct Layer{T}
+  alpha::T
+  beta::T
+  length::Int
+end
+Layer(n::Int) = Layer(randn(n), zeros(n), n)
+
+Functors.@functor Layer
+
+# Both array fields will be, for example, moved to the GPU:
+Functors.children(Layer(3))  # (alpha = [...], beta = [...], length)
+
+Optimisers.trainable(x::Layer) = (; alpha = x.alpha)  # must be a subset of children
+
+# Only the first field will be optimised:
+st = Optimisers.setup(DecayDescent(0.1), Layer(3))
+```
+
+## Tied Parameters
+
+If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.
+Within Optimisers.jl, `setup` will initialise once, and use the same `Leaf` for both parameters.
+Then `update` will accumulate the gradient from both, and the updated model returned will have the tie maintained.
+
+```julia
+using Flux, Optimisers
+
+enc = Chain(Dense(40 => 20, tanh), Dense(20 => 10));
+dec = Chain(Dense(enc[1].weight', true, tanh), Dense(enc[2].weight', true, tanh));
+model = Chain(; enc, dec)
+
+st = Optimisers.setup(Optimisers.Adam(), model);
+
+st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent  # true
+```
+
+This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s.
+It will not at present work for `reshape`d arrays, nor for immutable arrays such as those
+from StaticArrays.jl.
+
 ## Obtaining a flat parameter vector
 
 Instead of a nested tree-like structure, sometimes it is convenient to have all the
@@ -143,10 +194,11 @@ st, flat = Optimisers.update(st, flat, ∇flat)
 ```
 
 Here `flat` contains only the 283 trainable parameters, while the non-trainable
-ones are preserved inside `re`.
+ones are preserved inside `re`, an object of type `Restructure`.
 When defining new layers, these can be specified if necessary by overloading [`trainable`](@ref).
 By default, all numeric arrays visible to [Functors.jl](https://github.com/FluxML/Functors.jl)
 are assumed to contain trainable parameters.
+Tied parameters (arrays appearing in different layers) are included only once in `flat`.
 
 Lux stores only the trainable parameters in `params`.
 This can also be flattened to a plain `Vector` in the same way:
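
The new sentence about `flat` says that tied arrays contribute their entries only once. A quick check of that claim might look like this (a sketch, not taken from the diff; the two-field NamedTuple is invented for illustration):

```julia
using Optimisers

x = rand(Float32, 3)
m = (a = x, b = x)            # tied: the same array appears twice

flat, re = destructure(m)
length(flat)                  # expected 3, not 6: the tied array is flattened once
```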

src/Optimisers.jl

Lines changed: 16 additions & 8 deletions
@@ -1,6 +1,6 @@
 module Optimisers
 
-using Functors: functor, fmap, isleaf
+using Functors: functor, fmap, isleaf, @functor, fmapstructure, children
 using LinearAlgebra
 
 include("interface.jl")
@@ -16,6 +16,10 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
   AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
   WeightDecay, ClipGrad, ClipNorm, OptimiserChain
 
+###
+### one-array functions
+###
+
 """
     Optimisers.apply!(rule::RuleType, state, parameters, gradient) -> (state, gradient)
 
@@ -57,6 +61,10 @@ julia> Optimisers.init(Momentum(), [1.0, 2.0])
 """
 init
 
+###
+### whole-model functions
+###
+
 """
     Optimisers.setup(rule, model) -> tree
 
@@ -69,7 +77,7 @@ or [`update!`](@ref).
 julia> m = (x = rand(3), y = (true, false), z = tanh);
 
 julia> Optimisers.setup(Momentum(), m)  # same field names as m
-(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = (nothing, nothing), z = nothing)
+(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())
 ```
 
 The recursion into structures uses Functors.jl, and any new `struct`s containing parameters
@@ -82,15 +90,15 @@ julia> struct Layer; mat; fun; end
 julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]);
 
 julia> Optimisers.setup(Momentum(), model)  # new struct is by default ignored
-(lay = nothing, vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
+(lay = (), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
 
 julia> destructure(model)
 (Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2))
 
 julia> using Functors; @functor Layer  # annotate this type as containing parameters
 
 julia> Optimisers.setup(Momentum(), model)
-(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = nothing), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
+(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
 
 julia> destructure(model)
 (Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6))
@@ -112,12 +120,12 @@ See also [`update!`](@ref), which will be faster for models of ordinary `Array`s
 julia> m = (x = Float32[1,2,3], y = tanh);
 
 julia> t = Optimisers.setup(Descent(0.1f0), m)
-(x = Leaf(Descent{Float32}(0.1), nothing), y = nothing)
+(x = Leaf(Descent{Float32}(0.1), nothing), y = ())
 
 julia> g = (x = [1,1,1], y = nothing);  # fake gradient
 
 julia> Optimisers.update(t, m, g)
-((x = Leaf(Descent{Float32}(0.1), nothing), y = nothing), (x = Float32[0.9, 1.9, 2.9], y = tanh))
+((x = Leaf(Descent{Float32}(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh))
 ```
 """
 update
@@ -157,8 +165,8 @@ true
 julia> m  # original should be discarded, may be mutated but no guarantee
 (x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])
 
-julia> t  # original state should likewise be discarded
-(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.333333, 0.466667]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]))
+julia> t == t2  # original state is in fact guaranteed to be mutated
+true
 ```
 """
 update!
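
Two behavioural changes show up in these doctests: non-trainable branches are now marked with the empty tuple `()` rather than `nothing`, and `update!` is now guaranteed to mutate the state tree it was given. A minimal sketch of the intended usage pattern, assuming the API shown in the doctests above:

```julia
using Optimisers

m = (x = Float32[1, 2, 3], y = tanh)
t = Optimisers.setup(Descent(0.1f0), m)   # (x = Leaf(...), y = ())  -- () marks the non-trainable branch

g = (x = Float32[1, 1, 1], y = nothing)   # fake gradient
t2, m2 = Optimisers.update!(t, m, g)

t == t2                                   # true: the old state is mutated, as the new doctest asserts
```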

src/adjust.jl

Lines changed: 3 additions & 3 deletions
@@ -13,15 +13,15 @@ To change just the learning rate, provide a number `η::Real`.
 julia> m = (vec = rand(Float32, 2), fun = sin);
 
 julia> st = Optimisers.setup(Nesterov(), m)  # stored momentum is initialised to zero
-(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = nothing)
+(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = ())
 
 julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing));  # with fake gradient
 
 julia> st
-(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = nothing)
+(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())
 
 julia> st = Optimisers.adjust(st, 0.123)  # change learning rate, stored momentum untouched
-(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = nothing)
+(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
 ```
 
 To change other parameters, `adjust` also accepts keyword arguments matching the field
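
The doctest above shows the positional form, which changes only the learning rate. The keyword form mentioned in the trailing context line might be used like this (a sketch; the field name `rho` for Nesterov's momentum is an assumption, not shown in this diff):

```julia
using Optimisers

m = (vec = rand(Float32, 2), fun = sin)
st = Optimisers.setup(Nesterov(), m)

st = Optimisers.adjust(st, 0.123)       # positional: new learning rate, momentum state untouched
st = Optimisers.adjust(st; rho = 0.95)  # keyword: change another hyperparameter by field name (assumed)
```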

src/interface.jl

Lines changed: 95 additions & 36 deletions
@@ -1,57 +1,120 @@
 
-using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero
+using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero, ZeroTangent
 base(dx::Tangent) = backing(canonicalize(dx))
 base(dx) = dx
 const Zero = Union{Nothing, AbstractZero}  # Union{Zygote, Diffractor}
 
 abstract type AbstractRule end
 
-struct Leaf{R,S}
+###
+### setup
+###
+
+mutable struct Leaf{R,S}  # mutable so that its identity encodes parameter sharing
   rule::R
   state::S
 end
 
-function setup(rule, x; seen = Base.IdSet())
-  rule isa AbstractRule || Base.depwarn("In future, all optimisation rules should be <: AbstractRule", :setup)
+@functor Leaf
+
+Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b)
+
+function setup(rule::AbstractRule, model)
+  cache = IdDict()
+  tree = _setup(rule, model; cache)
+  isempty(cache) && @warn "setup found no trainable parameters in this model"
+  tree
+end
+
+# _setup is almost fmapstructure, but needs a _trainable_walk, and a cache which ignores numbers etc.
+function _setup(rule, x; cache)
+  haskey(cache, x) && return cache[x]
   if isnumeric(x)
-    x in seen && throw(ArgumentError("Optimisers.jl does not at present handle tied weights, sorry."))
-    isbits(x) || push!(seen, x)
-    return Leaf(rule, init(rule, x))
-  elseif isleaf(x)
-    return nothing
+    ℓ = Leaf(rule, init(rule, x))
+    if isbits(x)
+      cache[nothing] = nothing  # just to disable the warning
+      ℓ
+    else
+      cache[x] = ℓ
+    end
   else
-    return map(xᵢ -> setup(rule, xᵢ; seen), _trainable(x))
+    map(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
   end
 end
 
-subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
+function Base.show(io::IO, ℓ::Leaf)  # show method is mostly to hide its long type!
+  ioc = IOContext(io, :compact => true)
+  print(ioc, "Leaf(", ℓ.rule, ", ")
+  show(ioc, ℓ.state)
+  print(ioc, ")")
+end
 
-update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x
-update!(::Nothing, x, x̄s...) = nothing, x
+###
+### update
+###
 
-update!(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x
-function update!(ℓ::Leaf, x, x̄s...)
-  s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...)
-  Leaf(ℓ.rule, s′), subtract!(x, x̄′)
+function update(tree, model, grad, higher...)
+  t′ = fmap(copy, tree; exclude = maywrite)  # walks inside Leaf
+  x′ = fmap(copy, model; exclude = maywrite)
+  update!(t′, x′, grad, higher...)
 end
 
-update!(tree, x, ::Zero, ::Zero...) = tree, x
-function update!(tree, x, x̄s...)
-  x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
-  x′, re = functor(typeof(x), x)
-  xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
-  map(first, xtree), re(map(last, xtree))
+function update!(tree, model, grad, higher...)
+  # First walk is to accumulate the gradient. This recursion visits every copy of
+  # shared leaves, but stops when branches are absent from the gradient:
+  grads = IdDict{Leaf, Any}()
+  _grads!(grads, tree, model, grad, higher...)
+  # Second walk is to update the model. The params cache is indexed by (tree, x),
+  # so that identified Leafs can tie isbits parameters, but setup won't do that for you:
+  newmodel = _update!(tree, model; grads, params = IdDict())
+  tree, newmodel  # note that tree is guaranteed to be updated. Also that it's not necessarily a tree.
+end
+
+function _update!(tree, x; grads, params)
+  haskey(params, (tree, x)) && return params[(tree, x)]
+  isbits(tree) && return x  # means () is not cached, and also (((),),)
+  x′, re = functor(x)
+  x′′ = re(map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
+  if ismutable(x′′)
+    params[(tree, x)] = x′′
+  else  # no ties to preserve between immutable structs, right?
+    x′′
+  end
 end
+function _update!(ℓ::Leaf, x; grads, params)
+  haskey(params, (ℓ, x)) && return params[(ℓ, x)]
+  params[(ℓ, x)] = if haskey(grads, ℓ)
+    ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...)
+    subtract!(x, x̄′)
+  else
+    x  # no gradient seen
+  end
+end
+
+subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
 
-function update(tree, x, x̄s...)
-  t′ = fmap(copy, tree; exclude = maywrite)
-  x′ = fmap(copy, x; exclude = maywrite)
-  update!(t′, x′, x̄s...)
+_grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing
+function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...)
+  x̄s₀ = get(dict, ℓ, map(_ -> ZeroTangent(), x̄s))
+  dict[ℓ] = map(+, x̄s, x̄s₀)  # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible.
+  nothing
+end
+_grads!(dict::IdDict, t, x, ::Zero...) = nothing
+function _grads!(dict::IdDict, tree, x, x̄s...)
+  # The only reason _grads! takes model is that functor(typeof(x), base(x̄)) may differ from
+  # functor(typeof(tree), base(x̄)), for things like Transpose
+  x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
+  x′, _ = functor(typeof(x), x)
+  foreach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
 end
 
 # default all rules to first order calls
 apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx)
 
+###
+### sources of truth
+###
+
 """
     isnumeric(x) -> Bool
 
@@ -98,8 +161,12 @@ function _trainable(ch::NamedTuple, tr::Tuple)  # for old Flux-style no-names tu
   map(c -> c in tr ? c : nothing, ch)
 end
 
+###
+### rule definition helpers
+###
+
 """
-    @.. x = x + y
+    @.. x = y + z
 
 Sometimes in-place broadcasting macro, for use in `apply!` rules.
 If `maywrite(x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`.
@@ -135,11 +202,3 @@ Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)
 
 onevalue(λ::T, x::AbstractArray{T}) where T = map(_ -> λ, x)
 onevalue(λ, x::AbstractArray{T}) where T = onevalue(convert(float(T), λ), x)
-
-function Base.show(io::IO, ℓ::Leaf)  # show method is mostly to hide its long type!
-  ioc = IOContext(io, :compact => true)
-  print(ioc, "Leaf(", ℓ.rule, ", ")
-  show(ioc, ℓ.state)
-  print(io, ")")
-end
-
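
The key design point in this file is the comment on `Leaf`: it is now `mutable`, so object identity (`===`), rather than value equality (`==`), is what encodes a tie. A small sketch of the distinction, using the non-exported `Optimisers.Leaf` directly (illustration only, not part of the diff):

```julia
using Optimisers

# Two Leafs with equal contents compare == via the children-based method above,
# but they are distinct objects, so they imply no parameter sharing.
ℓ1 = Optimisers.Leaf(Descent(0.1), nothing)
ℓ2 = Optimisers.Leaf(Descent(0.1), nothing)

ℓ1 == ℓ2    # true:  same rule, same state
ℓ1 === ℓ2   # false: setup records a tie only by re-using the *same* Leaf object
```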
