
Commit 1908a1c

Add all-keyword constructors, much like @kwdef (#160)
* add all-keyword constructors
* update a few docstrings
* docstrings
* add tests
* one lost γ should be λ
1 parent e60b71e commit 1908a1c
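In practice, the change lets every rule be constructed either positionally (as before, documented with the Unicode names) or entirely by ASCII keyword. A brief usage sketch based on the docstrings and tests in this commit; the exact values are illustrative:

```julia
using Optimisers

# Positional constructor, unchanged: learning rate first, then momentum.
Momentum(0.01, 0.9)

# New all-keyword constructor: ASCII names, any order, defaults for every field.
Momentum(eta = 0.01, rho = 0.9)

# The added tests check this kind of equivalence, e.g. for Nesterov:
Nesterov(rho = 0.8, eta = 0.1) === Nesterov(0.1, 0.8)  # true
```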

3 files changed: +37 -16 lines changed

src/interface.jl

Lines changed: 9 additions & 3 deletions
@@ -241,7 +241,8 @@ like this:
 struct Rule
   eta::Float64
   beta::Tuple{Float64, Float64}
-  Rule(eta = 0.1, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
+  Rule(eta, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
+  Rule(; eta = 0.1, beta = (0.7, 0.8)) = Rule(eta, beta)
 end
 ```
 Any field called `eta` is assumed to be a learning rate, and cannot be negative.
@@ -259,12 +260,17 @@ macro def(expr)
     lines[i] = :($name::$typeof($val))
   end
   rule = Meta.isexpr(expr.args[2], :<:) ? expr.args[2].args[1] : expr.args[2]
+  params = [Expr(:kw, nv...) for nv in zip(names,vals)]
   check = :eta in names ? :(eta < 0 && throw(DomainError(eta, "the learning rate cannot be negative"))) : nothing
-  inner = :(function $rule($([Expr(:kw, nv...) for nv in zip(names,vals)]...))
+  # Positional-argument method, has defaults for all but the first arg:
+  inner = :(function $rule($(names[1]), $(params[2:end]...))
     $check
     new($(names...))
   end)
-  push!(lines, inner)
+  # Keyword-argument method. (Made an inner constructor only to allow
+  # resulting structs to be @doc... cannot if macro returns a block.)
+  kwmethod = :($rule(; $(params...)) = $rule($(names...)))
+  push!(lines, inner, kwmethod)
   esc(expr)
 end

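To make the macro change above concrete, here is a hand-expanded sketch of roughly what `@def struct Momentum <: AbstractRule ... end` now generates. This is an approximation of the expansion, not the macro's literal output, and the stand-in `AbstractRule` is only there so the sketch runs outside the package:

```julia
abstract type AbstractRule end  # stand-in for Optimisers.AbstractRule

struct Momentum <: AbstractRule
  eta::Float64   # field types come from typeof() of each default
  rho::Float64
  # Positional inner constructor: defaults for all but the first field,
  # plus the negative-learning-rate check because a field is named `eta`.
  function Momentum(eta, rho = 0.9)
    eta < 0 && throw(DomainError(eta, "the learning rate cannot be negative"))
    new(eta, rho)
  end
  # Keyword inner constructor: defaults for every field, forwards to the positional one.
  Momentum(; eta = 0.01, rho = 0.9) = Momentum(eta, rho)
end
```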
src/rules.jl

Lines changed: 22 additions & 13 deletions
@@ -8,18 +8,19 @@

 """
     Descent(η = 1f-1)
+    Descent(; eta)

 Classic gradient descent optimiser with learning rate `η`.
 For each parameter `p` and its gradient `dp`, this runs `p -= η*dp`.

 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
 """
 struct Descent{T} <: AbstractRule
   eta::T
 end
-Descent() = Descent(1f-1)
+Descent(; eta = 1f-1) = Descent(eta)

 init(o::Descent, x::AbstractArray) = nothing

@@ -37,13 +38,14 @@ end

 """
     Momentum(η = 0.01, ρ = 0.9)
+    Momentum(; [eta, rho])

 Gradient descent optimizer with learning rate `η` and momentum `ρ`.

 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Momentum (`ρ`): Controls the acceleration of gradient descent in the
+- Momentum (`ρ == rho`): Controls the acceleration of gradient descent in the
   prominent direction, in effect dampening oscillations.
 """
 @def struct Momentum <: AbstractRule
@@ -89,6 +91,7 @@ end

 """
     RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred = false)
+    RMSProp(; [eta, rho, epsilon, centred])

 Optimizer using the
 [RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
@@ -99,11 +102,11 @@ generally don't need tuning.
 gradients by an estimate their variance, instead of their second moment.

 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Momentum (`ρ`): Controls the acceleration of gradient descent in the
+- Momentum (`ρ == rho`): Controls the acceleration of gradient descent in the
   prominent direction, in effect dampening oscillations.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 - Keyword `centred` (or `centered`): Indicates whether to use centred variant
   of the algorithm.
@@ -115,10 +118,11 @@ struct RMSProp <: AbstractRule
   centred::Bool
 end

-function RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
+function RMSProp(η, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
   η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
   RMSProp(η, ρ, ϵ, centred | centered)
 end
+RMSProp(; eta = 0.001, rho = 0.9, epsilon = 1e-8, kw...) = RMSProp(eta, rho, epsilon; kw...)

 init(o::RMSProp, x::AbstractArray) = (zero(x), o.centred ? zero(x) : false)

@@ -488,22 +492,27 @@ end

 """
     AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
+    AdamW(; [eta, beta, lambda, epsilon])

 [AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
 weight decay regularization.
+Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`](@ref).

 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+- Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
-- Weight decay (`λ`): Controls the strength of ``L_2`` regularisation.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Weight decay (`λ == lambda`): Controls the strength of ``L_2`` regularisation.
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 """
-AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8) =
+AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
   OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))

+AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0, epsilon = 1e-8) =
+  OptimiserChain(Adam(eta, beta, epsilon), WeightDecay(lambda))
+
 """
     AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)

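A short usage sketch of the keyword methods this file now defines for the hand-written rules (the values are illustrative, not recommendations):

```julia
using Optimisers

Descent(eta = 0.1)                    # same as Descent(0.1)
RMSProp(eta = 0.001, centred = true)  # extra keywords forward to the positional method
AdamW(lambda = 0.3)                   # OptimiserChain(Adam(), WeightDecay(0.3)), as the new test checks
```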
test/runtests.jl

Lines changed: 6 additions & 0 deletions
@@ -330,6 +330,12 @@ end
   @test_throws ArgumentError Optimisers.thaw!(m)
 end

+@testset "keyword arguments" begin
+  @test Nesterov(rho=0.8, eta=0.1) === Nesterov(0.1, 0.8)
+  @test AdamW(lambda=0.3).opts[1] == Adam()
+  @test AdamW(lambda=0.3).opts[2] == WeightDecay(0.3)
+end
+
 @testset "forgotten gradient" begin
   x = [1.0, 2.0]
   sx = Optimisers.setup(Descent(), x)
