Skip to content

Commit 7f375aa

Browse files
Merge pull request #1819 from cossio/eps
make eps a parameter of optimisers
2 parents 2399588 + 43279cc commit 7f375aa

File tree

1 file changed

+32
-22
lines changed

1 file changed

+32
-22
lines changed

src/optimise/optimisers.jl

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -133,16 +133,17 @@ opt = RMSProp(0.002, 0.95)
133133
mutable struct RMSProp <: AbstractOptimiser
134134
eta::Float64
135135
rho::Float64
136+
epsilon::Float64
136137
acc::IdDict
137138
end
138139

139-
RMSProp= 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
140+
RMSProp= 0.001, ρ = 0.9, ϵ = ϵ) = RMSProp(η, ρ, ϵ, IdDict())
140141

141142
function apply!(o::RMSProp, x, Δ)
142143
η, ρ = o.eta, o.rho
143144
acc = get!(() -> zero(x), o.acc, x)::typeof(x)
144145
@. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
145-
@. Δ *= η / (acc + ϵ)
146+
@. Δ *= η / (acc + o.epsilon)
146147
end
147148

148149
"""
@@ -166,10 +167,11 @@ opt = ADAM(0.001, (0.9, 0.8))
166167
mutable struct ADAM <: AbstractOptimiser
167168
eta::Float64
168169
beta::Tuple{Float64,Float64}
170+
epsilon::Float64
169171
state::IdDict
170172
end
171173

172-
ADAM= 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
174+
ADAM= 0.001, β = (0.9, 0.999), ϵ = ϵ) = ADAM(η, β, ϵ, IdDict())
173175

174176
function apply!(o::ADAM, x, Δ)
175177
η, β = o.eta, o.beta
@@ -180,7 +182,7 @@ function apply!(o::ADAM, x, Δ)
180182

181183
@. mt = β[1] * mt + (1 - β[1]) * Δ
182184
@. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
183-
@. Δ = mt / (1 - βp[1]) / ((vt / (1 - βp[2])) + ϵ) * η
185+
@. Δ = mt / (1 - βp[1]) / ((vt / (1 - βp[2])) + o.epsilon) * η
184186
βp .= βp .* β
185187

186188
return Δ
@@ -207,10 +209,11 @@ opt = RADAM(0.001, (0.9, 0.8))
207209
mutable struct RADAM <: AbstractOptimiser
208210
eta::Float64
209211
beta::Tuple{Float64,Float64}
212+
epsilon::Float64
210213
state::IdDict
211214
end
212215

213-
RADAM= 0.001, β = (0.9, 0.999)) = RADAM(η, β, IdDict())
216+
RADAM= 0.001, β = (0.9, 0.999), ϵ = ϵ) = RADAM(η, β, ϵ, IdDict())
214217

215218
function apply!(o::RADAM, x, Δ)
216219
η, β = o.eta, o.beta
@@ -225,7 +228,7 @@ function apply!(o::RADAM, x, Δ)
225228
ρ = ρ∞ - 2t[] * βp[2] / (1 - βp[2])
226229
if ρ > 4
227230
r = sqrt((ρ-4)*-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ))
228-
@. Δ = mt / (1 - βp[1]) / ((vt / (1 - βp[2])) + ϵ) * η * r
231+
@. Δ = mt / (1 - βp[1]) / ((vt / (1 - βp[2])) + o.epsilon) * η * r
229232
else
230233
@. Δ = mt / (1 - βp[1]) * η
231234
end
@@ -256,10 +259,11 @@ opt = AdaMax(0.001, (0.9, 0.995))
256259
mutable struct AdaMax <: AbstractOptimiser
257260
eta::Float64
258261
beta::Tuple{Float64,Float64}
262+
epsilon::Float64
259263
state::IdDict
260264
end
261265

262-
AdaMax= 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
266+
AdaMax= 0.001, β = (0.9, 0.999), ϵ = ϵ) = AdaMax(η, β, ϵ, IdDict())
263267

264268
function apply!(o::AdaMax, x, Δ)
265269
η, β = o.eta, o.beta
@@ -270,7 +274,7 @@ function apply!(o::AdaMax, x, Δ)
270274

271275
@. mt = β[1] * mt + (1 - β[1]) * Δ
272276
@. ut = max(β[2] * ut, abs(Δ))
273-
@. Δ =/(1 - βp[1])) * mt/(ut + ϵ)
277+
@. Δ =/(1 - βp[1])) * mt/(ut + o.epsilon)
274278
βp .= βp .* β
275279

276280
return Δ
@@ -298,10 +302,11 @@ opt = OADAM(0.001, (0.9, 0.995))
298302
mutable struct OADAM <: AbstractOptimiser
299303
eta::Float64
300304
beta::Tuple{Float64,Float64}
305+
epsilon::Float64
301306
state::IdDict
302307
end
303308

304-
OADAM= 0.001, β = (0.5, 0.9)) = OADAM(η, β, IdDict())
309+
OADAM= 0.001, β = (0.5, 0.9), ϵ = ϵ) = OADAM(η, β, ϵ, IdDict())
305310

306311
function apply!(o::OADAM, x, Δ)
307312
η, β = o.eta, o.beta
@@ -313,7 +318,7 @@ function apply!(o::OADAM, x, Δ)
313318
@. mt = β[1] * mt + (1 - β[1]) * Δ
314319
@. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
315320
@. Δ = -Δ_
316-
@. Δ_ = η * mt / (1 - βp[1]) / ((vt / (1 - βp[2])) + ϵ)
321+
@. Δ_ = η * mt / (1 - βp[1]) / ((vt / (1 - βp[2])) + o.epsilon)
317322
@. Δ += 2Δ_
318323
βp .= βp .* β
319324

@@ -340,16 +345,17 @@ opt = ADAGrad(0.001)
340345
"""
341346
mutable struct ADAGrad <: AbstractOptimiser
342347
eta::Float64
348+
epsilon::Float64
343349
acc::IdDict
344350
end
345351

346-
ADAGrad= 0.1) = ADAGrad(η, IdDict())
352+
ADAGrad= 0.1, ϵ = ϵ) = ADAGrad, ϵ, IdDict())
347353

348354
function apply!(o::ADAGrad, x, Δ)
349355
η = o.eta
350-
acc = get!(() -> fill!(similar(x), ϵ), o.acc, x)::typeof(x)
356+
acc = get!(() -> fill!(similar(x), o.epsilon), o.acc, x)::typeof(x)
351357
@. acc += Δ * conj(Δ)
352-
@. Δ *= η / (acc + ϵ)
358+
@. Δ *= η / (acc + o.epsilon)
353359
end
354360

355361
"""
@@ -371,18 +377,19 @@ opt = ADADelta(0.89)
371377
"""
372378
mutable struct ADADelta <: AbstractOptimiser
373379
rho::Float64
380+
epsilon::Float64
374381
state::IdDict
375382
end
376383

377-
ADADelta= 0.9) = ADADelta(ρ, IdDict())
384+
ADADelta= 0.9, ϵ = ϵ) = ADADelta, ϵ, IdDict())
378385

379386
function apply!(o::ADADelta, x, Δ)
380387
ρ = o.rho
381388
acc, Δacc = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
382389
@. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
383390
# DON'T remove epsilon from numerator
384391
# or even out of the square roots
385-
@. Δ *= (Δacc + ϵ) / (acc + ϵ)
392+
@. Δ *= (Δacc + o.epsilon) / (acc + o.epsilon)
386393
@. Δacc = ρ * Δacc + (1 - ρ) * Δ * conj(Δ)
387394
return Δ
388395
end
@@ -409,22 +416,23 @@ opt = AMSGrad(0.001, (0.89, 0.995))
409416
mutable struct AMSGrad <: AbstractOptimiser
410417
eta::Float64
411418
beta::Tuple{Float64, Float64}
419+
epsilon::Float64
412420
state::IdDict
413421
end
414422

415-
AMSGrad= 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
423+
AMSGrad= 0.001, β = (0.9, 0.999), ϵ = ϵ) = AMSGrad(η, β, ϵ, IdDict())
416424

417425
function apply!(o::AMSGrad, x, Δ)
418426
η, β = o.eta, o.beta
419427

420428
mt, vt, v̂t = get!(o.state, x) do
421-
(fill!(similar(x), ϵ), fill!(similar(x), ϵ), fill!(similar(x), ϵ))
429+
(fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon))
422430
end :: NTuple{3,typeof(x)}
423431

424432
@. mt = β[1] * mt + (1 - β[1]) * Δ
425433
@. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
426434
@. v̂t = max(v̂t, vt)
427-
@. Δ = η * mt / (v̂t + ϵ)
435+
@. Δ = η * mt / (v̂t + o.epsilon)
428436
end
429437

430438
"""
@@ -449,10 +457,11 @@ opt = NADAM(0.002, (0.89, 0.995))
449457
mutable struct NADAM <: AbstractOptimiser
450458
eta::Float64
451459
beta::Tuple{Float64, Float64}
460+
epsilon::Float64
452461
state::IdDict
453462
end
454463

455-
NADAM= 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
464+
NADAM= 0.001, β = (0.9, 0.999), ϵ = ϵ) = NADAM(η, β, ϵ, IdDict())
456465

457466
function apply!(o::NADAM, x, Δ)
458467
η, β = o.eta, o.beta
@@ -464,7 +473,7 @@ function apply!(o::NADAM, x, Δ)
464473

465474
@. mt = β[1] * mt + (1 - β[1]) * Δ
466475
@. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
467-
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / ((vt * β[2] / (1 - β2p)) + ϵ) * η
476+
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / ((vt * β[2] / (1 - β2p)) + o.epsilon) * η
468477
βp .= βp .* β
469478

470479
return Δ
@@ -515,17 +524,18 @@ opt = AdaBelief(0.001, (0.9, 0.8))
515524
mutable struct AdaBelief
516525
eta::Float64
517526
beta::Tuple{Float64,Float64}
527+
epsilon::Float64
518528
state::IdDict
519529
end
520530

521-
AdaBelief= 0.001, β = (0.9, 0.999)) = AdaBelief(η, β, IdDict())
531+
AdaBelief= 0.001, β = (0.9, 0.999), ϵ = ϵ) = AdaBelief(η, β, ϵ, IdDict())
522532

523533
function apply!(o::AdaBelief, x, Δ)
524534
η, β = o.eta, o.beta
525535
mt, st = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
526536
@. mt = β[1] * mt + (1 - β[1]) * Δ
527537
@. st = β[2] * st + (1 - β[2]) *- mt) * conj- mt)
528-
@. Δ = η * mt / ((st) + ϵ)
538+
@. Δ = η * mt / ((st) + o.epsilon)
529539
return Δ
530540
end
531541

0 commit comments

Comments
 (0)