"""
    Descent(η = 1f-1)
+     Descent(; eta)

Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `dp`, this runs `p -= η*dp`.

# Parameters
- - Learning rate (`η`): Amount by which gradients are discounted before updating
+ - Learning rate (`η == eta`): Amount by which gradients are discounted before updating
  the weights.
"""
struct Descent{T} <: AbstractRule
  eta::T
end
- Descent() = Descent(1f-1)
+ Descent(; eta = 1f-1) = Descent(eta)

init(o::Descent, x::AbstractArray) = nothing
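
For readers skimming the hunk above, a minimal usage sketch of the two constructor spellings after this change (illustrative values, assuming the patch is applied to Optimisers.jl):

```julia
using Optimisers  # assumes this patch is applied

# Positional and keyword spellings build the same rule; `eta` is the
# ASCII keyword name for the learning rate `η` introduced by this patch.
Descent(0.05)          # positional form, unchanged
Descent(eta = 0.05)    # new keyword form; plain Descent() still defaults to 1f-1

# The update itself is `p -= η*dp`; sketched by hand on an array:
p, dp, η = [1.0, 2.0], [0.5, 0.5], 0.05
p .-= η .* dp          # p is now [0.975, 1.975]
```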
"""
39
40
Momentum(η = 0.01, ρ = 0.9)
41
+ Momentum(; [eta, rho])
40
42
41
43
Gradient descent optimizer with learning rate `η` and momentum `ρ`.
42
44
43
45
# Parameters
44
- - Learning rate (`η`): Amount by which gradients are discounted before updating
46
+ - Learning rate (`η == eta `): Amount by which gradients are discounted before updating
45
47
the weights.
46
- - Momentum (`ρ`): Controls the acceleration of gradient descent in the
48
+ - Momentum (`ρ == rho `): Controls the acceleration of gradient descent in the
47
49
prominent direction, in effect dampening oscillations.
48
50
"""
49
51
@def struct Momentum <: AbstractRule
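
A similar sketch for Momentum, assuming the `@def`-generated struct also gains the keyword constructor that the new docstring signature advertises:

```julia
using Optimisers  # assumes this patch is applied

# Equivalent spellings, as implied by the added docstring line:
Momentum(0.01, 0.9)              # positional form, unchanged
Momentum(eta = 0.01, rho = 0.9)  # keyword form; either keyword may be omitted
```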
"""
91
93
RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred = false)
94
+ RMSProp(; [eta, rho, epsilon, centred])
92
95
93
96
Optimizer using the
94
97
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
@@ -99,11 +102,11 @@ generally don't need tuning.
99
102
gradients by an estimate their variance, instead of their second moment.
100
103
101
104
# Parameters
102
- - Learning rate (`η`): Amount by which gradients are discounted before updating
105
+ - Learning rate (`η == eta `): Amount by which gradients are discounted before updating
103
106
the weights.
104
- - Momentum (`ρ`): Controls the acceleration of gradient descent in the
107
+ - Momentum (`ρ == rho `): Controls the acceleration of gradient descent in the
105
108
prominent direction, in effect dampening oscillations.
106
- - Machine epsilon (`ϵ`): Constant to prevent division by zero
109
+ - Machine epsilon (`ϵ == epsilon `): Constant to prevent division by zero
107
110
(no need to change default)
108
111
- Keyword `centred` (or `centered`): Indicates whether to use centred variant
109
112
of the algorithm.
@@ -115,10 +118,11 @@ struct RMSProp <: AbstractRule
115
118
centred:: Bool
116
119
end
117
120
118
- function RMSProp (η = 0.001 , ρ = 0.9 , ϵ = 1e-8 ; centred:: Bool = false , centered:: Bool = false )
121
+ function RMSProp (η, ρ = 0.9 , ϵ = 1e-8 ; centred:: Bool = false , centered:: Bool = false )
119
122
η < 0 && throw (DomainError (η, " the learning rate cannot be negative" ))
120
123
RMSProp (η, ρ, ϵ, centred | centered)
121
124
end
125
+ RMSProp (; eta = 0.001 , rho = 0.9 , epsilon = 1e-8 , kw... ) = RMSProp (eta, rho, epsilon; kw... )
122
126
123
127
init (o:: RMSProp , x:: AbstractArray ) = (zero (x), o. centred ? zero (x) : false )
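
A short sketch of the new RMSProp calling convention; note that the keyword method forwards `centred`/`centered` through `kw...` to the positional method (illustrative values):

```julia
using Optimisers  # assumes this patch is applied

# The keyword method fills in eta/rho/epsilon and forwards anything else
# (here `centred = true`) to the positional method via `kw...`.
RMSProp(0.001, 0.9, 1e-8; centred = true)  # positional form, unchanged
RMSProp(eta = 0.001, centred = true)       # keyword form added above
```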
@@ -488,22 +492,27 @@ end
"""
490
494
AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
495
+ AdamW(; [eta, beta, lambda, epsilon])
491
496
492
497
[AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
493
498
weight decay regularization.
499
+ Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`](@ref)`.
494
500
495
501
# Parameters
496
- - Learning rate (`η`): Amount by which gradients are discounted before updating
502
+ - Learning rate (`η == eta `): Amount by which gradients are discounted before updating
497
503
the weights.
498
- - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
504
+ - Decay of momentums (`β::Tuple == beta `): Exponential decay for the first (β1) and the
499
505
second (β2) momentum estimate.
500
- - Weight decay (`λ`): Controls the strength of ``L_2`` regularisation.
501
- - Machine epsilon (`ϵ`): Constant to prevent division by zero
506
+ - Weight decay (`λ == lambda `): Controls the strength of ``L_2`` regularisation.
507
+ - Machine epsilon (`ϵ == epsilon `): Constant to prevent division by zero
502
508
(no need to change default)
503
509
"""
504
- AdamW (η = 0.001 , β = (0.9 , 0.999 ), λ = 0 , ϵ = 1e-8 ) =
510
+ AdamW (η, β = (0.9 , 0.999 ), λ = 0. 0 , ϵ = 1e-8 ) =
505
511
OptimiserChain (Adam (η, β, ϵ), WeightDecay (λ))
506
512
513
+ AdamW (; eta = 0.001 , beta = (0.9 , 0.999 ), lambda = 0 , epsilon = 1e-8 ) =
514
+ OptimiserChain (Adam (eta, beta, epsilon), WeightDecay (lambda))
515
+
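
Finally, an illustrative sketch of the two AdamW spellings introduced by this hunk:

```julia
using Optimisers  # assumes this patch is applied

# Both spellings expand to OptimiserChain(Adam(...), WeightDecay(...)):
AdamW(0.001, (0.9, 0.999), 0.01)   # positional: η, β, λ (ϵ keeps its default)
AdamW(eta = 0.001, lambda = 0.01)  # keyword form added above; beta/epsilon default
```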
"""
    AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)