@@ -141,7 +141,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
141
141
# RMSProp update step for parameter `x` with gradient `Δ`.
#
# Reads the step size `o.eta` and decay `o.rho`, keeps one accumulator per
# parameter in `o.acc` (created as `zero(x)` on first use, keyed by `x`),
# and rescales `Δ` in place. Returns the mutated `Δ`.
# `ϵ` is a constant defined outside this block.
function apply!(o::RMSProp, x, Δ)
  η, ρ = o.eta, o.rho
  acc = get!(() -> zero(x), o.acc, x)::typeof(x)
  # Running average of the squared gradient magnitude. `abs2` (rather than
  # `Δ^2`) keeps the accumulator real and non-negative for complex gradients;
  # for real gradients it is the same value as `Δ^2`.
  @. acc = ρ * acc + (1 - ρ) * abs2(Δ)
  @. Δ *= η / (√acc + ϵ)
  return Δ
end
147
147
@@ -179,7 +179,7 @@ function apply!(o::ADAM, x, Δ)
179
179
end :: Tuple{typeof(x),typeof(x),Vector{Float64}}
180
180
181
181
@. mt = β[1 ] * mt + (1 - β[1 ]) * Δ
182
- @. vt = β[2 ] * vt + (1 - β[2 ]) * Δ^ 2
182
+ @. vt = β[2 ] * vt + (1 - β[2 ]) * Δ * conj (Δ)
183
183
@. Δ = mt / (1 - βp[1 ]) / (√ (vt / (1 - βp[2 ])) + ϵ) * η
184
184
βp .= βp .* β
185
185
@@ -221,7 +221,7 @@ function apply!(o::RADAM, x, Δ)
221
221
end :: Tuple{typeof(x),typeof(x),Vector{Float64},Ref{Int}}
222
222
223
223
@. mt = β[1 ] * mt + (1 - β[1 ]) * Δ
224
- @. vt = β[2 ] * vt + (1 - β[2 ]) * Δ^ 2
224
+ @. vt = β[2 ] * vt + (1 - β[2 ]) * Δ * conj (Δ)
225
225
ρ = ρ∞ - 2 t[] * βp[2 ] / (1 - βp[2 ])
226
226
if ρ > 4
227
227
r = sqrt ((ρ- 4 )* (ρ- 2 )* ρ∞/ ((ρ∞- 4 )* (ρ∞- 2 )* ρ))
@@ -311,7 +311,7 @@ function apply!(o::OADAM, x, Δ)
311
311
end :: Tuple{typeof(x),typeof(x),typeof(x),Vector{Float64}}
312
312
313
313
@. mt = β[1 ] * mt + (1 - β[1 ]) * Δ
314
- @. vt = β[2 ] * vt + (1 - β[2 ]) * Δ^ 2
314
+ @. vt = β[2 ] * vt + (1 - β[2 ]) * Δ * conj (Δ)
315
315
@. Δ = - Δ_
316
316
@. Δ_ = η * mt / (1 - βp[1 ]) / (√ (vt / (1 - βp[2 ])) + ϵ)
317
317
@. Δ += 2 Δ_
@@ -348,7 +348,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
348
348
# ADAGrad update step for parameter `x` with gradient `Δ`.
#
# Reads the step size `o.eta`, keeps one accumulator per parameter in
# `o.acc` (created as an array shaped like `x` filled with `ϵ` on first use,
# keyed by `x`), and rescales `Δ` in place. Returns the mutated `Δ`.
# `ϵ` is a constant defined outside this block.
function apply!(o::ADAGrad, x, Δ)
  η = o.eta
  acc = get!(() -> fill!(similar(x), ϵ), o.acc, x)::typeof(x)
  # Accumulate the squared gradient magnitude. `abs2` (rather than `Δ^2`)
  # keeps the accumulator real and non-negative for complex gradients;
  # for real gradients it is the same value as `Δ^2`.
  @. acc += abs2(Δ)
  @. Δ *= η / (√acc + ϵ)
  return Δ
end
354
354
@@ -379,11 +379,11 @@ ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
379
379
# ADADelta update step for parameter `x` with gradient `Δ`.
#
# Reads the decay `o.rho` and keeps two accumulators per parameter in
# `o.state` (both created as `zero(x)` on first use, keyed by `x`):
# `acc` averages the squared gradient magnitude and `Δacc` averages the
# squared magnitude of the rescaled update. `Δ` is rescaled in place and
# returned. `ϵ` is a constant defined outside this block.
function apply!(o::ADADelta, x, Δ)
  ρ = o.rho
  acc, Δacc = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
  # `abs2` (rather than `Δ^2`) keeps the accumulators real and non-negative
  # for complex gradients; for real gradients it is the same value as `Δ^2`.
  @. acc = ρ * acc + (1 - ρ) * abs2(Δ)
  # DON'T remove epsilon from numerator
  # or even out of the square roots
  @. Δ *= √(Δacc + ϵ) / √(acc + ϵ)
  # Uses the already-rescaled Δ, so this statement must stay after the
  # in-place rescale above.
  @. Δacc = ρ * Δacc + (1 - ρ) * abs2(Δ)
  return Δ
end
389
389
@@ -463,7 +463,7 @@ function apply!(o::NADAM, x, Δ)
463
463
β1p, β2p = βp
464
464
465
465
@. mt = β[1 ] * mt + (1 - β[1 ]) * Δ
466
- @. vt = β[2 ] * vt + (1 - β[2 ]) * Δ^ 2
466
+ @. vt = β[2 ] * vt + (1 - β[2 ]) * Δ * conj (Δ)
467
467
@. Δ = (β[1 ] * mt / (1 - β[1 ] * β1p) + (1 - β[1 ]) * Δ / (1 - β1p)) / (√ (vt * β[2 ] / (1 - β2p)) + ϵ) * η
468
468
βp .= βp .* β
469
469
@@ -524,7 +524,7 @@ function apply!(o::AdaBelief, x, Δ)
524
524
η, β = o. eta, o. beta
525
525
mt, st = get! (() -> (zero (x), zero (x)), o. state, x):: NTuple{2,typeof(x)}
526
526
@. mt = β[1 ] * mt + (1 - β[1 ]) * Δ
527
- @. st = β[2 ] * st + (1 - β[2 ]) * (Δ - mt)^ 2
527
+ @. st = β[2 ] * st + (1 - β[2 ]) * (Δ - mt) * conj (Δ - mt)
528
528
@. Δ = η * mt / (√ (st) + ϵ)
529
529
return Δ
530
530
end
0 commit comments