Commit 71c5a75

Add adjust (#89)
* Optimisers.adjust
* move to setup, add tests, etc
* reverse adjust signature to match init, apply
* don't overload setup, and make adjust its own file
* docs
* new, simpler, version
* doc changes
* simplify doc
* fix doctest
1 parent cdc64ef commit 71c5a75

File tree (6 files changed: +138 −11 lines)

- docs/src/api.md
- docs/src/index.md
- src/Optimisers.jl
- src/adjust.jl
- src/rules.jl
- test/runtests.jl

docs/src/api.md

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@ Optimisers.OptimiserChain
 Optimisers.setup
 Optimisers.update
 Optimisers.update!
+Optimisers.adjust(::Any, ::Real)
 ```
 
 Calling `Functors.@functor` on your model's layer types by default causes the
@@ -57,4 +58,5 @@ Optimisers.apply!
 Optimisers.init
 Optimisers.@..
 Optimisers.@lazy
+Optimisers.adjust(::AbstractRule, ::Real)
 ```

docs/src/index.md

Lines changed: 2 additions & 2 deletions
@@ -8,13 +8,13 @@ These act on one array of parameters:
 ```julia
 # Define a container to hold any optimiser specific parameters (if any):
 struct DecayDescent{T} <: Optimisers.AbstractRule
-  η::T
+  eta::T
 end
 
 # Define an `apply!` rule which encodes how the gradients will be used to
 # update the parameters:
 function Optimisers.apply!(o::DecayDescent, state, x, x̄)
-  newx̄ = (o.η / state) .* x̄
+  newx̄ = (o.eta / state) .* x̄
   nextstate = state + 1
   return nextstate, newx̄
 end
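
The rename from `η` to `eta` in this docs example is what lets the new `adjust` work on user-defined rules: the generic method rebuilds a rule by replacing a field named `eta`, so a rule that keeps its learning rate under that name needs no extra code. A minimal sketch under that assumption (the `init` method, model, and numbers are illustrative, following the docs example):

```julia
using Optimisers

struct DecayDescent{T} <: Optimisers.AbstractRule
  eta::T
end

Optimisers.init(o::DecayDescent, x::AbstractArray) = 1  # state is just a step counter

function Optimisers.apply!(o::DecayDescent, state, x, x̄)
  newx̄ = (o.eta / state) .* x̄
  return state + 1, newx̄
end

m = (w = ones(Float32, 3),)
st = Optimisers.setup(DecayDescent(0.1f0), m)
st = Optimisers.adjust(st, 0.01)  # finds and replaces the `eta` field
st.w.rule.eta                     # 0.01f0, converted by the default constructor
```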

src/Optimisers.jl

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@ using LinearAlgebra
 include("interface.jl")
 export AbstractRule
 
+include("adjust.jl")
+
 include("destructure.jl")
 export destructure
 

src/adjust.jl

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+
+"""
+    Optimisers.adjust(tree, η) -> tree
+
+Alters the state `tree = setup(rule, model)` to change the parameters of the
+optimisation rule, without destroying its stored state. Typically used mid-way
+through training.
+
+To change just the learning rate, provide a number `η::Real`.
+
+# Example
+```jldoctest
+julia> m = (vec = rand(Float32, 2), fun = sin);
+
+julia> st = Optimisers.setup(Nesterov(), m) # stored momentum is initialised to zero
+(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = nothing)
+
+julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing)); # with fake gradient
+
+julia> st
+(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = nothing)
+
+julia> st = Optimisers.adjust(st, 0.123) # change learning rate, stored momentum untouched
+(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = nothing)
+```
+
+To change other parameters, `adjust` also accepts keyword arguments matching the field
+names of the optimisation rule's type.
+
+```
+julia> fieldnames(Adam)
+(:eta, :beta, :epsilon)
+
+julia> st2 = Optimisers.setup(OptimiserChain(ClipGrad(), Adam()), m)
+(vec = Leaf(OptimiserChain(ClipGrad{Float32}(10.0), Adam{Float32}(0.001, (0.9, 0.999), 1.19209f-7)), [nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))]), fun = nothing)
+
+julia> Optimisers.adjust(st2; beta = (0.777, 0.909), delta = 11.1) # delta acts on ClipGrad
+(vec = Leaf(OptimiserChain(ClipGrad{Float32}(11.1), Adam{Float32}(0.001, (0.777, 0.909), 1.19209f-7)), [nothing, (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))]), fun = nothing)
+
+julia> Optimisers.adjust(st; beta = "no such field") # silently ignored!
+(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = nothing)
+```
+"""
+adjust(tree, eta::Real) = map(st -> adjust(st, eta), tree)
+adjust(tree; kw...) = map(st -> adjust(st; kw...), tree)
+
+adjust(::Nothing, ::Real) = nothing
+adjust(::Nothing; kw...) = nothing
+
+adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state)
+adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state)
+
+
+"""
+    Optimisers.adjust(rule::RuleType, η::Real) -> rule
+
+If a new optimisation rule has a learning rate which is *not* stored in field `rule.eta`,
+then you may wish to add a method to `adjust`. (But it is simpler to just use the standard name.)
+"""
+adjust(r::AbstractRule, eta::Real) = _adjust(r, (; eta))
+adjust(r::AbstractRule; kw...) = _adjust(r, NamedTuple(kw))
+
+function _adjust(r::T, nt::NamedTuple) where T <: AbstractRule
+  isempty(nt) && throw(ArgumentError("adjust must be given something to act on!"))
+  fs = fieldnames(T)
+  vals = map(fs) do field
+    get(nt, field, getfield(r, field))
+  end
+  T(vals...) # relies on having the default constructor
+end
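
A note on the fallback: `_adjust` rebuilds the rule through its default constructor, replacing only fields whose names appear in the keyword list, so the positional `adjust(st, η)` silently does nothing to a rule whose learning rate is not stored in a field called `eta`. The docstring above suggests adding a method in that case; here is a hedged sketch with an invented `SignDescent` rule (its `lr` field and `apply!` logic are illustrative, not part of this commit):

```julia
using Optimisers

struct SignDescent{T} <: Optimisers.AbstractRule
  lr::T  # learning rate deliberately *not* called `eta`
end

Optimisers.init(o::SignDescent, x::AbstractArray) = nothing
Optimisers.apply!(o::SignDescent, state, x, x̄) = (state, o.lr .* sign.(x̄))

# Without this method, `Optimisers.adjust(st, 0.05)` would leave `lr` untouched,
# because the generic method only swaps a field named `eta`:
Optimisers.adjust(o::SignDescent, eta::Real) = SignDescent(eta)

m = (w = zeros(Float32, 3),)
st = Optimisers.setup(SignDescent(0.1f0), m)
st = Optimisers.adjust(st, 0.05)   # now reaches the method above
st.w.rule.lr                       # 0.05
```

The keyword form `Optimisers.adjust(st; lr = 0.05)` would already work through `_adjust`, since the field name matches; only the positional learning-rate shortcut needs the extra method.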

src/rules.jl

Lines changed: 17 additions & 9 deletions
@@ -129,11 +129,16 @@ function apply!(o::RMSProp, state, x, dx)
   return (quad, lin), dx′
 end
 
+function adjust(r::RMSProp; kw...)
+  :centred in keys(kw) && throw(ArgumentError("adjust(::RMSProp; centred) is not allowed, as the variants store different states"))
+  _adjust(r, NamedTuple(kw)) # that's why _adjust exists!
+end
+
 function Base.show(io::IO, o::RMSProp)
-    show(io, typeof(o))
-    print(io, "(")
-    join(io, [o.eta, o.rho, o.epsilon], ", ")
-    print(io, "; centred = ", o.centred, ")")
+  show(io, typeof(o))
+  print(io, "(")
+  join(io, [o.eta, o.rho, o.epsilon], ", ")
+  print(io, "; centred = ", o.centred, ")")
 end
 
 
@@ -542,7 +547,7 @@ See also [`ClipNorm`](@ref).
 struct ClipGrad{T<:Real} <: AbstractRule
   delta::T
 end
-ClipGrad() = ClipGrad(10f0)
+ClipGrad(δ::Integer = 10) = ClipGrad(Float32(δ)) # float is to ensure adjust can change this
 
 init(o::ClipGrad, x::AbstractArray) = nothing
 
@@ -569,7 +574,7 @@ struct ClipNorm{T<:Real} <: AbstractRule
   p::T
   throw::Bool
 end
-ClipNorm(ω = 10f0, p = 2; throw::Bool = true) = ClipNorm{typeof(ω)}(ω, p, throw)
+ClipNorm(ω = 10f0, p = 2; throw::Bool = true) = ClipNorm{float(typeof(ω))}(ω, p, throw)
 
 init(o::ClipNorm, x::AbstractArray) = nothing
 
@@ -595,12 +600,12 @@ This is equivalent to `Descent(1)`.
 
 # Example
 ```jldoctest
-julia> o = OptimiserChain(ClipGrad(1), Descent(0.1));
+julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));
 
 julia> m = (zeros(3),);
 
 julia> s = Optimisers.setup(o, m)
-(Leaf(OptimiserChain(ClipGrad{Int64}(1), Descent{Float64}(0.1)), [nothing, nothing]),)
+(Leaf(OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1)), [nothing, nothing]),)
 
 julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
 ([-0.03, -0.1, -0.1],)
@@ -626,4 +631,7 @@ function Base.show(io::IO, c::OptimiserChain)
   print(io, "OptimiserChain(")
   join(io, c.opts, ", ")
   print(io, ")")
-end
+end
+
+adjust(ℓ::OptimiserChain, eta::Real) = OptimiserChain(map(opt -> adjust(opt, eta), ℓ.opts)...)
+adjust(ℓ::OptimiserChain; kw...) = OptimiserChain(map(opt -> adjust(opt; kw...), ℓ.opts)...)
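
The two `OptimiserChain` methods at the end broadcast `adjust` over every component rule, so a single learning-rate change reaches each rule that stores an `eta` field, keyword adjustments only affect rules that own the matching field, and the new `RMSProp` method guards the one field that cannot be changed safely. A hedged usage sketch (model, values, and variable names are illustrative):

```julia
using Optimisers

m = (w = zeros(Float32, 3),)
st = Optimisers.setup(OptimiserChain(ClipGrad(5.0), Adam(0.001)), m)

st = Optimisers.adjust(st, 3e-4)        # Adam's eta changes; ClipGrad has no eta field, so it is untouched
st = Optimisers.adjust(st; delta = 1.0) # only ClipGrad has a `delta` field, so only it changes

# The guarded case: centred and non-centred RMSProp store different state,
# so switching via adjust is rejected with an ArgumentError:
st2 = Optimisers.setup(RMSProp(), m)
# Optimisers.adjust(st2; centred = true)  # would throw ArgumentError
```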

test/runtests.jl

Lines changed: 44 additions & 0 deletions
@@ -130,6 +130,50 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
   @test eltype(m4[1]) == Float16 # because of explicit broadcast in subtract!
   @test eltype(m4[2]) == Float32
 end
+
+@testset "adjusting parameters" begin
+  # Simple momentum:
+  m = (α = ([0.0], sin), γ = Float32[4,3,2])
+  s = Optimisers.setup(Momentum(0.1, 0.9), m)
+  s1, m1 = Optimisers.update(s, m, (α = nothing, γ = [1,10,100],))
+  @test m.γ .- m1.γ ≈ [0.1, 1, 10]
+  @test s1.γ.rule.eta == 0.1
+  @test s1.γ.state ≈ [0.1, 1, 10]
+
+  s2 = Optimisers.adjust(s1, 0.2)
+  @test s2.γ.rule.eta == 0.2
+  @test s2.γ.rule.rho == 0.9
+  @test s2.γ.state == s1.γ.state
+  @test s2.α[1].rule.eta == 0.2
+  @test s2.α[1].state == s1.α[1].state
+
+  s3 = Optimisers.adjust(s1; eta=0.3, rho=0.7)
+  @test s3.γ.rule.eta == 0.3
+  @test s3.γ.rule.rho == 0.7
+  @test s3.γ.state == s1.γ.state
+  @test s3.α[1].rule.rho == 0.7
+
+  _, m3 = Optimisers.update(s3, m, (α = nothing, γ = [1,10,100],))
+  @test !(m.γ .- m3.γ ≈ [1, 10, 100])
+
+  @test s1 == Optimisers.adjust(s1, zeta = "this does nothing")
+
+  # OptimiserChain
+  sc = Optimisers.setup(OptimiserChain(ClipGrad(2), Adam()), m)
+  sc1, mc1 = Optimisers.update(sc, m, (α = nothing, γ = [1,10,100],))
+  @test sc1.γ.rule.opts[2].eta == 0.001f0
+  @test sc1.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
+
+  sc2 = Optimisers.adjust(sc1, 0.2)
+  @test sc2.γ.rule.opts[1].delta == 2 # unchanged
+  @test sc2.γ.rule.opts[2].eta === 0.2f0
+  @test sc2.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
+
+  sc2 = Optimisers.adjust(sc1; delta = 2.5) # ClipGrad(2) does not store an Int, for this reason
+  @test sc2.γ.rule.opts[1].delta == 2.5
+  @test sc2.γ.rule.opts[2].eta === 0.001f0 # unchanged
+  @test sc2.γ.state[2][1] ≈ [0.1, 0.2, 0.2]
+end
 
 
 @testset "forgotten gradient" begin
   x = [1.0, 2.0]
