@@ -1,6 +1,8 @@
 using Flux
 using MacroTools: @forward
 
+abstract type AbstractOptimiser end
+
 const ϵ = 1e-8
 
 # TODO: should use weak refs
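The new `AbstractOptimiser` supertype gives every built-in optimiser (and any user-defined one) a common parent, so code can dispatch on or constrain to "any optimiser". Below is a minimal sketch of what that enables, assuming the `apply!(opt, x, Δ)` hook used by the built-in types (see `apply!(o::ClipValue, x, Δ)` further down in this diff); `ScaledDescent` is a hypothetical name, not part of this change:

```
# Hypothetical user-defined optimiser; with the new supertype it slots into
# the same hierarchy as Descent, ADAM, etc.
mutable struct ScaledDescent <: AbstractOptimiser
  eta::Float64
end

# Rescale the gradient in place, mirroring the apply!(opt, x, Δ) interface
# used by the built-in optimisers.
apply!(o::ScaledDescent, x, Δ) = (Δ .*= o.eta; Δ)
```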
@@ -30,7 +32,7 @@
 Flux.Optimise.update!(opt, ps, gs)
 ```
 """
-mutable struct Descent
+mutable struct Descent <: AbstractOptimiser
   eta::Float64
 end
 
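For context, the `update!` call quoted in the docstring above is typically used in a training step like the following (a hedged sketch; the model `m`, `loss`, `x`, and `y` are placeholders, not part of this diff):

```
# Illustrative training step with plain gradient descent.
opt = Descent(0.1)                        # constructed exactly as before
ps  = Flux.params(m)                      # m is some Flux model (placeholder)
gs  = Flux.gradient(() -> loss(m(x), y), ps)
Flux.Optimise.update!(opt, ps, gs)
```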
@@ -58,7 +60,7 @@ opt = Momentum()
 opt = Momentum(0.01, 0.99)
 ```
 """
-mutable struct Momentum
+mutable struct Momentum <: AbstractOptimiser
   eta::Float64
   rho::Float64
   velocity::IdDict
@@ -91,7 +93,7 @@ opt = Nesterov()
 opt = Nesterov(0.003, 0.95)
 ```
 """
-mutable struct Nesterov
+mutable struct Nesterov <: AbstractOptimiser
   eta::Float64
   rho::Float64
   velocity::IdDict
@@ -128,7 +130,7 @@ opt = RMSProp()
 opt = RMSProp(0.002, 0.95)
 ```
 """
-mutable struct RMSProp
+mutable struct RMSProp <: AbstractOptimiser
   eta::Float64
   rho::Float64
   acc::IdDict
@@ -161,7 +163,7 @@ opt = ADAM()
 opt = ADAM(0.001, (0.9, 0.8))
 ```
 """
-mutable struct ADAM
+mutable struct ADAM <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64,Float64}
   state::IdDict
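The `ADAM` fields correspond to the usual Adam hyperparameters: `eta` is the step size, `beta` the two moment-decay rates, and `state` holds per-parameter moment estimates. For reference, the textbook Adam update those fields feed into looks roughly like this (an illustrative sketch of the standard rule, not the implementation in this file; all names are placeholders):

```
# Textbook Adam step for one array `x` with gradient `g`, running moment
# estimates `m` and `v`, and step counter `t`.
function adam_step!(x, g, m, v, t; eta = 0.001, beta = (0.9, 0.999), eps = 1e-8)
  @. m = beta[1] * m + (1 - beta[1]) * g      # first-moment estimate
  @. v = beta[2] * v + (1 - beta[2]) * g^2    # second-moment estimate
  mhat = m ./ (1 - beta[1]^t)                 # bias correction
  vhat = v ./ (1 - beta[2]^t)
  @. x -= eta * mhat / (sqrt(vhat) + eps)     # parameter update
  return x
end
```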
@@ -202,7 +204,7 @@ opt = RADAM()
 opt = RADAM(0.001, (0.9, 0.8))
 ```
 """
-mutable struct RADAM
+mutable struct RADAM <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64,Float64}
   state::IdDict
@@ -251,7 +253,7 @@ opt = AdaMax()
 opt = AdaMax(0.001, (0.9, 0.995))
 ```
 """
-mutable struct AdaMax
+mutable struct AdaMax <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64,Float64}
   state::IdDict
@@ -293,7 +295,7 @@ opt = OADAM()
 opt = OADAM(0.001, (0.9, 0.995))
 ```
 """
-mutable struct OADAM
+mutable struct OADAM <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64,Float64}
   state::IdDict
@@ -336,7 +338,7 @@ opt = ADAGrad()
 opt = ADAGrad(0.001)
 ```
 """
-mutable struct ADAGrad
+mutable struct ADAGrad <: AbstractOptimiser
   eta::Float64
   acc::IdDict
 end
@@ -367,7 +369,7 @@ opt = ADADelta()
 opt = ADADelta(0.89)
 ```
 """
-mutable struct ADADelta
+mutable struct ADADelta <: AbstractOptimiser
   rho::Float64
   state::IdDict
 end
@@ -404,7 +406,7 @@ opt = AMSGrad()
 opt = AMSGrad(0.001, (0.89, 0.995))
 ```
 """
-mutable struct AMSGrad
+mutable struct AMSGrad <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64, Float64}
   state::IdDict
@@ -444,7 +446,7 @@ opt = NADAM()
 opt = NADAM(0.002, (0.89, 0.995))
 ```
 """
-mutable struct NADAM
+mutable struct NADAM <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64, Float64}
   state::IdDict
@@ -537,7 +539,7 @@ Combine several optimisers into one; each optimiser produces a modified gradient
 that will be fed into the next, and this is finally applied to the parameter as
 usual.
 """
-mutable struct Optimiser
+mutable struct Optimiser <: AbstractOptimiser
   os::Vector{Any}
 end
 
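The docstring above describes the composition behaviour: each wrapped optimiser rewrites the gradient and hands it to the next. Conceptually that is a left-to-right pass over `os`, as in this illustrative sketch (the name `chained_apply!` is hypothetical; it is not this file's exact code):

```
# Conceptual sketch: thread the gradient through every wrapped optimiser.
function chained_apply!(o::Optimiser, x, Δ)
  for inner in o.os
    Δ = apply!(inner, x, Δ)   # each optimiser returns a modified gradient
  end
  return Δ
end
```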
@@ -567,7 +569,7 @@ The wrapped optimiser's step size is not modified.
 Optimiser(InvDecay(..), Opt(..))
 ```
 """
-mutable struct InvDecay
+mutable struct InvDecay <: AbstractOptimiser
   gamma::Float64
   state::IdDict
 end
@@ -604,7 +606,7 @@ Optimiser(ExpDecay(..), Opt(..))
 opt = Optimiser(ExpDecay(), ADAM())
 ```
 """
-mutable struct ExpDecay
+mutable struct ExpDecay <: AbstractOptimiser
   eta::Float64
   decay::Float64
   step::Int64
@@ -632,7 +634,7 @@ Decay weights by `wd`.
 # Parameters
 - Weight decay (`wd`)
 """
-mutable struct WeightDecay
+mutable struct WeightDecay <: AbstractOptimiser
   wd::Real
 end
 
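Since `WeightDecay` is itself an optimiser in this hierarchy, the usual pattern is to compose it with a main optimiser, as the `Optimiser` docstrings above suggest (a hedged usage sketch; the `1e-4` value is only an example, and `ps`/`gs` are as in the earlier training-step sketch):

```
# L2-style weight decay applied to the gradient before the ADAM step.
opt = Optimiser(WeightDecay(1e-4), ADAM(0.001))
Flux.Optimise.update!(opt, ps, gs)
```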
@@ -648,7 +650,7 @@
 
 Clip gradients when their absolute value exceeds `thresh`.
 """
-mutable struct ClipValue{T}
+mutable struct ClipValue{T} <: AbstractOptimiser
   thresh::T
 end
 
@@ -659,7 +661,7 @@ apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh)
 
 Clip gradients when their L2 norm exceeds `thresh`.
 """
-mutable struct ClipNorm{T}
+mutable struct ClipNorm{T} <: AbstractOptimiser
   thresh::T
 end
 
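With `ClipValue` and `ClipNorm` now under the same `AbstractOptimiser` supertype, they compose with the other optimisers in the usual way (a hedged usage sketch; the thresholds are illustrative values only):

```
# Clip each gradient entry to [-0.5, 0.5], then take an ADAM step.
opt = Optimiser(ClipValue(0.5), ADAM(0.001))

# Or clip by overall L2 norm instead.
opt = Optimiser(ClipNorm(1.0), ADAM(0.001))
```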