
Commit e60b71e

WeightDecay for L1 norm (#159)
* WeightDecay for L1 norm
* better words
* change to lambda alpha, add tests
* change to lambda, add tests
* tweaks
* shashed in October - makes two structs instead
* version with simple SignDecay instead
* change SignDecay penalty to be called kappa
* restore depwarn for WeightDecay, was called gamma
* change kappa back to lambda
1 parent 6473c45 commit e60b71e

5 files changed: +67 −16 lines


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
 authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
-version = "0.3.1"
+version = "0.3.2"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/Optimisers.jl

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ export destructure
 include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
-       WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
+       WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
        AccumGrad
 
 ###

src/rules.jl

Lines changed: 59 additions & 13 deletions
@@ -487,7 +487,7 @@ function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
 end
 
 """
-    AdamW(η = 0.001, β = (0.9, 0.999), γ = 0, ϵ = 1e-8)
+    AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
 
 [AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
 weight decay regularization.
@@ -497,12 +497,12 @@ weight decay regularization.
   the weights.
 - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
-- Weight decay (`γ`): Decay applied to weights during optimisation.
+- Weight decay (`λ`): Controls the strength of ``L_2`` regularisation.
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-AdamW(η = 0.001, β = (0.9, 0.999), γ = 0, ϵ = 1e-8) =
-  OptimiserChain(Adam(η, β, ϵ), WeightDecay(γ))
+AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8) =
+  OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))
 
 """
     AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
@@ -538,35 +538,79 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
 end
 
 """
-    WeightDecay(γ = 5e-4)
+    WeightDecay(λ = 5e-4)
 
-Decay weights by ``γ``, that is, add `γ .* x` to the gradient `x̄` which will be
-subtracted from `x`.
+Implements ``L_2`` regularisation, also known as ridge regression,
+when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).
 
-Typically composed with other optimisers as the first transformation in an [`OptimiserChain`](@ref).
-This is equivalent to adding ``L_2`` regularization with coefficient ``γ`` to the loss.
+It does this by adding `λ .* x` to the gradient. This is equivalent to adding
+`λ/2 * sum(abs2, x) == λ/2 * norm(x)^2` to the loss.
+
+See also [`SignDecay`] for ``L_1`` normalisation.
 
 # Parameters
-- Weight decay (`γ`): Decay applied to weights during optimisation.
+- Penalty (`λ ≥ 0`): Controls the strength of the regularisation.
 """
 @def struct WeightDecay <: AbstractRule
-  gamma = 5e-4
+  lambda = 5e-4
 end
 
 init(o::WeightDecay, x::AbstractArray) = nothing
 
 function apply!(o::WeightDecay, state, x::AbstractArray{T}, dx) where T
-  γ = T(o.gamma)
-  dx′ = @lazy dx + γ * x
+  λ = T(o.lambda)
+  dx′ = @lazy dx + λ * x
 
   return state, dx′
 end
 
+function adjust(r::WeightDecay; gamma = nothing, kw...)
+  if isnothing(gamma)
+    return _adjust(r, NamedTuple(kw))
+  else
+    Base.depwarn("The strength of WeightDecay is now field :lambda, not :gamma", :adjust, force=true)
+    nt = (; lambda = gamma, NamedTuple(kw)...)
+    return _adjust(r, nt)
+  end
+end
+
+"""
+    SignDecay(λ = 1e-3)
+
+Implements ``L_1`` regularisation, also known as LASSO regression,
+when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).
+
+It does this by adding `λ .* sign(x)` to the gradient. This is equivalent to adding
+`λ * sum(abs, x) == λ * norm(x, 1)` to the loss.
+
+See also [`WeightDecay`] for ``L_2`` normalisation.
+They can be used together: `OptimiserChain(SignDecay(0.012), WeightDecay(0.034), Adam())`
+is equivalent to adding `0.012 * norm(x, 1) + 0.017 * norm(x, 2)^2` to the loss function.
+
+# Parameters
+- Penalty (`λ ≥ 0`): Controls the strength of the regularisation.
+"""
+@def struct SignDecay <: AbstractRule
+  lambda = 1e-3
+end
+
+init(o::SignDecay, x::AbstractArray) = nothing
+
+function apply!(o::SignDecay, state, x::AbstractArray{T}, dx) where T
+  λ = T(o.lambda)
+  dx′ = @lazy dx + λ * sign(x)
+
+  return state, dx′
+end
+
+
 """
     ClipGrad(δ = 10)
 
 Restricts every gradient component to obey `-δ ≤ dx[i] ≤ δ`.
 
+Typically composed with other rules using [`OptimiserChain`](@ref).
+
 See also [`ClipNorm`](@ref).
 """
 @def struct ClipGrad <: AbstractRule
@@ -591,6 +635,8 @@ to stay at this threshold (unless `p==0`).
 Throws an error if the norm is infinite or `NaN`,
 which you can turn off with `throw = false`.
 
+Typically composed with other rules using [`OptimiserChain`](@ref).
+
 See also [`ClipGrad`](@ref).
 """
 struct ClipNorm <: AbstractRule
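For orientation, a minimal usage sketch of the rules this file now defines; it is not part of the diff, and the array, gradient, and penalty values are arbitrary placeholders:

    using Optimisers

    x  = randn(Float32, 3)    # a parameter array
    dx = randn(Float32, 3)    # its gradient

    # L1 penalty applied before Adam, as the SignDecay docstring recommends:
    rule  = OptimiserChain(SignDecay(1e-3), Adam(1e-3))
    state = Optimisers.setup(rule, x)
    state, x = Optimisers.update(state, x, dx)

    # The deprecation path added above: `gamma` still works but warns,
    # because the WeightDecay field is now called `lambda`.
    Optimisers.adjust(WeightDecay(5e-4); lambda = 1e-4)   # new keyword
    Optimisers.adjust(WeightDecay(5e-4); gamma = 1e-4)    # old keyword, emits a depwarn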

test/rules.jl

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ RULES = [
   AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
   AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(),
   # A few chained combinations:
-  OptimiserChain(WeightDecay(), Adam(0.001)),
+  OptimiserChain(SignDecay(0.001), Adam(0.001)),
   OptimiserChain(ClipNorm(), Adam(0.001)),
   OptimiserChain(ClipGrad(0.5), Momentum()),
   OptimiserChain(WeightDecay(), OAdam(), ClipGrad(1)),

test/runtests.jl

Lines changed: 5 additions & 0 deletions
@@ -137,6 +137,7 @@ end
 @testset "OptimiserChain" begin
   x = [1, 10, 100.0]; dx = [1, 2, 3.0];
   @test Optimisers.update(Optimisers.setup(WeightDecay(0.1), x), x, dx)[2] ≈ [1-0.1-1, 10-1-2, 100-10-3]
+  @test Optimisers.update(Optimisers.setup(SignDecay(0.1), x), x, dx)[2] ≈ [1-0.1-1, 10-0.1-2, 100-0.1-3]
   @test Optimisers.update(Optimisers.setup(ClipGrad(2), x), x, dx)[2] ≈ [1-1, 10-2, 100-2]
 
   o2 = OptimiserChain(ClipGrad(2), WeightDecay(0.1))
@@ -154,6 +155,10 @@ end
 
   o0 = OptimiserChain()
   @test Optimisers.update(Optimisers.setup(o0, x), x, dx)[2] ≈ [1-1,10-2,100-3]
+
+  # L1 norm via sign
+  xm = [1, -10, 100.0]; dxm = [3, 2, -1];
+  @test Optimisers.update(Optimisers.setup(SignDecay(0.1), xm), xm, dxm)[2] ≈ [1-0.1-3, -10+0.1-2, 100-0.1+1]
 end
 
 @testset "trainable subset" begin
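As a quick sanity check (not part of the diff), the expected values in the new SignDecay test can be reproduced by hand: used on its own, the rule makes the update subtract `dx .+ λ .* sign.(x)` from `x`, just as the WeightDecay test above subtracts `dx .+ λ .* x`:

    xm  = [1, -10, 100.0]
    dxm = [3, 2, -1.0]
    λ   = 0.1
    dx′ = dxm .+ λ .* sign.(xm)   # [3.1, 1.9, -0.9], the gradient after SignDecay
    xm .- dx′                     # ≈ [1-0.1-3, -10+0.1-2, 100-0.1+1], matching the @test above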
