
Commit 06970a5

Merge pull request #1325 from DhairyaLGandhi/dg/absopt
Add AbstractOptimiser type
2 parents 87d546c + 5ac0176 · commit 06970a5

File tree

1 file changed (+20 −18 lines)


src/optimise/optimisers.jl

Lines changed: 20 additions & 18 deletions
@@ -1,6 +1,8 @@
 using Flux
 using MacroTools: @forward

+abstract type AbstractOptimiser end
+
 const ϵ = 1e-8

 # TODO: should use weak refs
@@ -30,7 +32,7 @@ end
 Flux.Optimise.update!(opt, ps, gs)
 ```
 """
-mutable struct Descent
+mutable struct Descent <: AbstractOptimiser
   eta::Float64
 end

@@ -58,7 +60,7 @@ opt = Momentum()
 opt = Momentum(0.01, 0.99)
 ```
 """
-mutable struct Momentum
+mutable struct Momentum <: AbstractOptimiser
   eta::Float64
   rho::Float64
   velocity::IdDict
@@ -91,7 +93,7 @@ opt = Nesterov()
 opt = Nesterov(0.003, 0.95)
 ```
 """
-mutable struct Nesterov
+mutable struct Nesterov <: AbstractOptimiser
   eta::Float64
   rho::Float64
   velocity::IdDict
@@ -128,7 +130,7 @@ opt = RMSProp()
 opt = RMSProp(0.002, 0.95)
 ```
 """
-mutable struct RMSProp
+mutable struct RMSProp <: AbstractOptimiser
   eta::Float64
   rho::Float64
   acc::IdDict
@@ -161,7 +163,7 @@ opt = ADAM()
 opt = ADAM(0.001, (0.9, 0.8))
 ```
 """
-mutable struct ADAM
+mutable struct ADAM <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64,Float64}
   state::IdDict
@@ -202,7 +204,7 @@ opt = RADAM()
 opt = RADAM(0.001, (0.9, 0.8))
 ```
 """
-mutable struct RADAM
+mutable struct RADAM <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64,Float64}
   state::IdDict
@@ -251,7 +253,7 @@ opt = AdaMax()
 opt = AdaMax(0.001, (0.9, 0.995))
 ```
 """
-mutable struct AdaMax
+mutable struct AdaMax <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64,Float64}
   state::IdDict
@@ -293,7 +295,7 @@ opt = OADAM()
 opt = OADAM(0.001, (0.9, 0.995))
 ```
 """
-mutable struct OADAM
+mutable struct OADAM <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64,Float64}
   state::IdDict
@@ -336,7 +338,7 @@ opt = ADAGrad()
 opt = ADAGrad(0.001)
 ```
 """
-mutable struct ADAGrad
+mutable struct ADAGrad <: AbstractOptimiser
   eta::Float64
   acc::IdDict
 end
@@ -367,7 +369,7 @@ opt = ADADelta()
 opt = ADADelta(0.89)
 ```
 """
-mutable struct ADADelta
+mutable struct ADADelta <: AbstractOptimiser
   rho::Float64
   state::IdDict
 end
@@ -404,7 +406,7 @@ opt = AMSGrad()
 opt = AMSGrad(0.001, (0.89, 0.995))
 ```
 """
-mutable struct AMSGrad
+mutable struct AMSGrad <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64, Float64}
   state::IdDict
@@ -444,7 +446,7 @@ opt = NADAM()
 opt = NADAM(0.002, (0.89, 0.995))
 ```
 """
-mutable struct NADAM
+mutable struct NADAM <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64, Float64}
   state::IdDict
@@ -537,7 +539,7 @@ Combine several optimisers into one; each optimiser produces a modified gradient
 that will be fed into the next, and this is finally applied to the parameter as
 usual.
 """
-mutable struct Optimiser
+mutable struct Optimiser <: AbstractOptimiser
   os::Vector{Any}
 end

@@ -567,7 +569,7 @@ The wrapped optimiser's step size is not modified.
 Optimiser(InvDecay(..), Opt(..))
 ```
 """
-mutable struct InvDecay
+mutable struct InvDecay <: AbstractOptimiser
   gamma::Float64
   state::IdDict
 end
@@ -604,7 +606,7 @@ Optimiser(ExpDecay(..), Opt(..))
 opt = Optimiser(ExpDecay(), ADAM())
 ```
 """
-mutable struct ExpDecay
+mutable struct ExpDecay <: AbstractOptimiser
   eta::Float64
   decay::Float64
   step::Int64
@@ -632,7 +634,7 @@ Decay weights by `wd`.
 # Parameters
 - Weight decay (`wd`)
 """
-mutable struct WeightDecay
+mutable struct WeightDecay <: AbstractOptimiser
   wd::Real
 end

@@ -648,7 +650,7 @@ end

 Clip gradients when their absolute value exceeds `thresh`.
 """
-mutable struct ClipValue{T}
+mutable struct ClipValue{T} <: AbstractOptimiser
   thresh::T
 end

@@ -659,7 +661,7 @@ apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh)

 Clip gradients when their L2 norm exceeds `thresh`.
 """
-mutable struct ClipNorm{T}
+mutable struct ClipNorm{T} <: AbstractOptimiser
   thresh::T
 end

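As an illustration of what the new supertype enables (not part of this diff), a user-defined optimiser can subtype AbstractOptimiser and implement `apply!`, the hook that `Flux.Optimise.update!(opt, ps, gs)` uses to turn a gradient into an update step. The sketch below is a minimal example; the name `ScaledDescent`, its field, and the default constructor are hypothetical.

using Flux

# Hypothetical optimiser: plain gradient descent with a fixed scaling factor.
mutable struct ScaledDescent <: Flux.Optimise.AbstractOptimiser
  eta::Float64
end
ScaledDescent() = ScaledDescent(0.1)

# `apply!` rescales the gradient in place and returns the step that
# `update!` subtracts from the parameters.
function Flux.Optimise.apply!(o::ScaledDescent, x, Δ)
  Δ .*= o.eta
  return Δ
end

With this definition, `ScaledDescent(0.01)` can be passed anywhere the built-in optimisers are accepted, for example to `Flux.Optimise.update!(opt, ps, gs)` as shown in the Descent docstring above.
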
0 commit comments
