Skip to content

Commit 2399588

Browse files
bors[bot] and cossio authored
Merge #1816
1816: ExpDecay start step r=DhairyaLGandhi a=cossio Adds an option to `ExpDecay` which specifies the step at which the exponential decay of the learning rate starts. Fixes #1815. ### PR Checklist - [x] Tests are added - [ ] Entry in NEWS.md - [x] Documentation, if applicable - [ ] API changes require approval from a committer (different from the author, if applicable) Co-authored-by: cossio <j.cossio.diaz@gmail.com>
2 parents fe803a1 + ab371bb commit 2399588

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

src/optimise/optimisers.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ function apply!(o::InvDecay, x, Δ)
594594
end
595595

596596
"""
597-
ExpDecay(η = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4)
597+
ExpDecay(η = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4, start = 0)
598598
599599
Discount the learning rate `η` by the factor `decay` every `decay_step` steps till
600600
a minimum of `clip`.
@@ -606,6 +606,7 @@ a minimum of `clip`.
606606
- `decay_step`: Schedule decay operations by setting the number of steps between
607607
two decay operations.
608608
- `clip`: Minimum value of learning rate.
609+
- `start`: Step at which the decay starts.
609610
610611
611612
See also the [Scheduling Optimisers](@ref) section of the docs
@@ -624,16 +625,17 @@ mutable struct ExpDecay <: AbstractOptimiser
624625
decay::Float64
625626
step::Int64
626627
clip::Float64
628+
start::Int64
627629
current::IdDict
628630
end
629631

630-
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4) =
631-
ExpDecay(opt, decay, decay_step, clip, IdDict())
632+
ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4, start = 0) =
633+
ExpDecay(opt, decay, decay_step, clip, start, IdDict())
632634

633635
function apply!(o::ExpDecay, x, Δ)
634-
η, s, decay = o.eta, o.step, o.decay
636+
η, s, decay, start = o.eta, o.step, o.decay, o.start
635637
n = o.current[x] = get(o.current, x, 0) + 1
636-
if o.current[x]%s == 0 && count(x -> x%s == 0, values(o.current)) == 1
638+
if n > start && n % s == 0 && count(x -> x > start && x % s == 0, values(o.current)) == 1
637639
η = max(η * decay, o.clip)
638640
o.eta = η
639641
end

test/optimise.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,16 @@ end
9898
@test eta_actual == eta_expected
9999
end
100100

101+
@testset "starting step" begin
102+
start = 4
103+
o = ExpDecay(0.2, 0.5, 1, 1e-3, start)
104+
p = [0.0]
105+
steps = 1:8
106+
eta_expected = @. max(o.eta * 0.5 ^ max(steps - start, 0), o.clip)
107+
eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps]
108+
@test eta_actual == eta_expected
109+
end
110+
101111
w = randn(10, 10)
102112
o = ExpDecay(0.1, 0.1, 1000, 1e-4)
103113
w1 = randn(10,10)

0 commit comments

Comments
 (0)