
Commit 2c702c6

Merge pull request #93 from MilkshakeForReal/RProp
Implement Rprop
2 parents: 746317b + b6a675a

File tree: 3 files changed, +46 -4 lines


src/Optimisers.jl

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ include("destructure.jl")
 export destructure

 include("rules.jl")
-export Descent, Adam, Momentum, Nesterov, RMSProp,
+export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
   AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
   WeightDecay, ClipGrad, ClipNorm, OptimiserChain


src/rules.jl

Lines changed: 43 additions & 1 deletion
@@ -110,6 +110,7 @@ struct RMSProp{T} <: AbstractRule
   epsilon::T
   centred::Bool
 end
+
 RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η)); centred::Bool = false, centered::Bool = false) =
   RMSProp{typeof(η)}(η, ρ, ϵ, centred | centered)

@@ -135,6 +136,47 @@ function Base.show(io::IO, o::RMSProp)
   print(io, "; centred = ", o.centred, ")")
 end

+"""
+    Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0))
+
+Optimizer using the
+[Rprop](https://ieeexplore.ieee.org/document/298623) algorithm, a full-batch
+learning algorithm that depends only on the sign of the gradient.
+
+# Parameters
+- Learning rate (`η`): Amount by which gradients are discounted before updating
+  the weights.
+
+- Scaling factors (`ℓ::Tuple`): Multiplicative increase and decrease factors.
+
+- Step sizes (`Γ::Tuple`): Minimal and maximal allowed step sizes.
+"""
+struct Rprop{T} <: AbstractRule
+  eta::T
+  ell::Tuple{T,T}
+  gamma::Tuple{T,T}
+end
+
+Rprop(η = 1f-3, ℓ = (5f-1, 1.2f0), Γ = (1f-6, 50f0)) = Rprop{typeof(η)}(η, ℓ, Γ)
+
+init(o::Rprop, x::AbstractArray) = (zero(x), onevalue(o.eta, x))
+
+function apply!(o::Rprop, state, x, dx)
+  ℓ, Γ = o.ell, o.gamma
+  g, η = state
+
+  η = broadcast(g, η, dx) do g, η, dx
+    g * dx > 0 ? min(η * ℓ[2], Γ[2]) : g * dx < 0 ? max(η * ℓ[1], Γ[1]) : η
+  end
+  g = broadcast(g, dx) do g, dx
+    g * dx < 0 ? zero(dx) : dx
+  end
+  dx′ = @lazy η * sign(g)
+
+  return (g, η), dx′
+end
+
 """
     Adam(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))


@@ -584,4 +626,4 @@ function Base.show(io::IO, c::OptimiserChain)
   print(io, "OptimiserChain(")
   join(io, c.opts, ", ")
   print(io, ")")
-end
+end
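
For context, a minimal sketch of how the newly exported rule plugs into the existing Optimisers.setup / Optimisers.update! API. The model, loss, and data below are made-up placeholders, and Zygote is assumed only as one possible gradient backend:

using Optimisers, Zygote   # Zygote used here purely to supply gradients

# hypothetical toy problem: fit a weight matrix to fixed random targets
w  = randn(Float32, 10, 10)
w′ = randn(Float32, 10, 10)
x  = randn(Float32, 10, 20)
loss(w) = sum(abs2, w * x .- w′ * x)

st = Optimisers.setup(Rprop(), w)   # state holds the previous gradient and per-weight step sizes
for _ in 1:100
    g, = Zygote.gradient(loss, w)
    st, w = Optimisers.update!(st, w, g)
end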

test/rules.jl

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@ Random.seed!(1)

 RULES = [
   # All the rules at default settings:
-  Descent(), Adam(), Momentum(), Nesterov(), RMSProp(),
+  Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
   AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
   AdamW(), RAdam(), OAdam(), AdaBelief(),
   # A few chained combinations:

@@ -39,7 +39,7 @@ end
   @test iloss(rand(10, 10), w, w′) > 1
   st = Optimisers.setup(o, w)
   for t = 1:10^5
-    x = rand(10)
+    x = rand(10, 20)
     gs = loggradient(o)(w -> iloss(x, w, w′), w)
     st, w = Optimisers.update!(st, w, gs...)
   end
