@@ -487,7 +487,7 @@ function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
end

"""
-     AdamW(η = 0.001, β = (0.9, 0.999), γ = 0, ϵ = 1e-8)
+     AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)

[AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
weight decay regularization.
@@ -497,12 +497,12 @@ weight decay regularization.
  the weights.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
  second (β2) momentum estimate.
- - Weight decay (`γ`): Decay applied to weights during optimisation.
+ - Weight decay (`λ`): Controls the strength of ``L_2`` regularisation.
- Machine epsilon (`ϵ`): Constant to prevent division by zero
  (no need to change default)
"""
- AdamW(η = 0.001, β = (0.9, 0.999), γ = 0, ϵ = 1e-8) =
-   OptimiserChain(Adam(η, β, ϵ), WeightDecay(γ))
+ AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8) =
+   OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))

"""
    AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
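To illustrate the rename, here is a minimal usage sketch (not part of the diff): it assumes the exported `setup`/`update` API, and that the arguments of `AdamW` above are optional positional arguments, so `λ` is passed third.

```julia
using Optimisers

# With the definition above, λ (the WeightDecay strength) is the third positional argument:
rule = AdamW(0.001, (0.9, 0.999), 0.01)

# ...which is shorthand for the explicit chain:
rule2 = OptimiserChain(Adam(0.001, (0.9, 0.999)), WeightDecay(0.01))

x  = randn(Float32, 3)                # a parameter array
st = Optimisers.setup(rule, x)        # per-parameter optimiser state
dx = 2 .* x                           # stand-in gradient
st, x = Optimisers.update(st, x, dx)  # one optimisation step
```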
@@ -538,35 +538,79 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
end

"""
-     WeightDecay(γ = 5e-4)
+     WeightDecay(λ = 5e-4)

- Decay weights by ``γ``, that is, add `γ .* x` to the gradient `x̄` which will be
- subtracted from `x`.
+ Implements ``L_2`` regularisation, also known as ridge regression,
+ when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).

- Typically composed with other optimisers as the first transformation in an [`OptimiserChain`](@ref).
- This is equivalent to adding ``L_2`` regularization with coefficient ``γ`` to the loss.
+ It does this by adding `λ .* x` to the gradient. This is equivalent to adding
+ `λ/2 * sum(abs2, x) == λ/2 * norm(x)^2` to the loss.
+
+ See also [`SignDecay`](@ref) for ``L_1`` regularisation.

# Parameters
- - Weight decay (`γ`): Decay applied to weights during optimisation.
+ - Penalty (`λ ≥ 0`): Controls the strength of the regularisation.
"""
@def struct WeightDecay <: AbstractRule
-   gamma = 5e-4
+   lambda = 5e-4
end

init(o::WeightDecay, x::AbstractArray) = nothing

function apply!(o::WeightDecay, state, x::AbstractArray{T}, dx) where T
-   γ = T(o.gamma)
-   dx′ = @lazy dx + γ * x
+   λ = T(o.lambda)
+   dx′ = @lazy dx + λ * x

  return state, dx′
end

+ function adjust(r::WeightDecay; gamma = nothing, kw...)
+   if isnothing(gamma)
+     return _adjust(r, NamedTuple(kw))
+   else
+     Base.depwarn("The strength of WeightDecay is now field :lambda, not :gamma", :adjust, force=true)
+     nt = (; lambda = gamma, NamedTuple(kw)...)
+     return _adjust(r, nt)
+   end
+ end
+
+ """
+     SignDecay(λ = 1e-3)
+
+ Implements ``L_1`` regularisation, also known as LASSO regression,
+ when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).
+
+ It does this by adding `λ .* sign(x)` to the gradient. This is equivalent to adding
+ `λ * sum(abs, x) == λ * norm(x, 1)` to the loss.
+
+ See also [`WeightDecay`](@ref) for ``L_2`` regularisation.
+ They can be used together: `OptimiserChain(SignDecay(0.012), WeightDecay(0.034), Adam())`
+ is equivalent to adding `0.012 * norm(x, 1) + 0.017 * norm(x, 2)^2` to the loss function.
+
+ # Parameters
+ - Penalty (`λ ≥ 0`): Controls the strength of the regularisation.
+ """
+ @def struct SignDecay <: AbstractRule
+   lambda = 1e-3
+ end
+
+ init(o::SignDecay, x::AbstractArray) = nothing
+
+ function apply!(o::SignDecay, state, x::AbstractArray{T}, dx) where T
+   λ = T(o.lambda)
+   dx′ = @lazy dx + λ * sign(x)
+
+   return state, dx′
+ end
+
+
"""
    ClipGrad(δ = 10)

Restricts every gradient component to obey `-δ ≤ dx[i] ≤ δ`.

+ Typically composed with other rules using [`OptimiserChain`](@ref).
+
See also [`ClipNorm`](@ref).
"""
@def struct ClipGrad <: AbstractRule
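As a quick check of the docstring claims above (again not part of the diff, and it requires this branch, where `SignDecay` exists and `WeightDecay` stores `lambda`): composing each rule with `Descent(1)` and a zero gradient isolates the penalty term.

```julia
using Optimisers

x  = Float32[1, -2, 3]
dx = zero(x)                     # zero gradient, so only the penalty contributes

# L2 penalty: the update subtracts λ .* x
st = Optimisers.setup(OptimiserChain(WeightDecay(0.1), Descent(1.0)), x)
_, x2 = Optimisers.update(st, x, dx)
@assert x2 ≈ x .- 0.1f0 .* x

# L1 penalty: the update subtracts λ .* sign.(x)
st = Optimisers.setup(OptimiserChain(SignDecay(0.1), Descent(1.0)), x)
_, x3 = Optimisers.update(st, x, dx)
@assert x3 ≈ x .- 0.1f0 .* sign.(x)

# The adjust method above keeps old keyword code working, with a deprecation warning:
r = Optimisers.adjust(WeightDecay(0.1); gamma = 0.2)  # returns a WeightDecay with lambda = 0.2
```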
@@ -591,6 +635,8 @@ to stay at this threshold (unless `p==0`).
Throws an error if the norm is infinite or `NaN`,
which you can turn off with `throw = false`.

+ Typically composed with other rules using [`OptimiserChain`](@ref).
+
See also [`ClipGrad`](@ref).
"""
struct ClipNorm <: AbstractRule
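For completeness, a composition sketch for the two clipping rules (not part of the diff; `ClipNorm(1.0)` assumes the norm threshold is its first constructor argument):

```julia
using Optimisers

# Clip each gradient component to [-0.5, 0.5] before Adam sees it:
rule = OptimiserChain(ClipGrad(0.5), Adam(1e-3))

x  = randn(Float32, 4)
st = Optimisers.setup(rule, x)
dx = Float32[10, -10, 0.1, -0.1]      # the large components are clipped to ±0.5
st, x = Optimisers.update(st, x, dx)

# Or rescale any whole-array gradient whose 2-norm exceeds 1:
rule2 = OptimiserChain(ClipNorm(1.0), Adam(1e-3))
```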