Commit 7af4f4c

Use conjugates in optimizers to better learn on complex-valued inputs
When weights are complex, the deltas to them will also be complex. In all optimizers that need a second-order estimate of gradient statistics, we generally want to use the `x * conj(x)` pattern, rather than `x^2`.
1 parent ea26f45 commit 7af4f4c
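
Why `x * conj(x)` rather than `x^2`: for a complex number z, z^2 is generally complex, while z * conj(z) == abs2(z) is real and non-negative, which is exactly what a second-moment accumulator needs; for real z the two coincide, so real-valued training is unaffected. A minimal illustration in plain Julia (independent of Flux):

z = 1.0 + 2.0im
z^2          # -3.0 + 4.0im: complex, unusable as a magnitude estimate
z * conj(z)  # 5.0 + 0.0im: real and non-negative, equal to abs2(z)

r = 3.0
r^2 == r * conj(r)  # true, so the change is a no-op for real weights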

File tree

1 file changed (+9, −9)


src/optimise/optimisers.jl

Lines changed: 9 additions & 9 deletions
@@ -141,7 +141,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
 function apply!(o::RMSProp, x, Δ)
   η, ρ = o.eta, o.rho
   acc = get!(() -> zero(x), o.acc, x)::typeof(x)
-  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
   @. Δ *= η / (√acc + ϵ)
 end

@@ -179,7 +179,7 @@ function apply!(o::ADAM, x, Δ)
   end :: Tuple{typeof(x),typeof(x),Vector{Float64}}

   @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
   @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
   βp .= βp .* β

@@ -221,7 +221,7 @@ function apply!(o::RADAM, x, Δ)
   end :: Tuple{typeof(x),typeof(x),Vector{Float64},Ref{Int}}

   @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
   ρ = ρ∞ - 2t[] * βp[2] / (1 - βp[2])
   if ρ > 4
     r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ))
@@ -311,7 +311,7 @@ function apply!(o::OADAM, x, Δ)
   end :: Tuple{typeof(x),typeof(x),typeof(x),Vector{Float64}}

   @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
   @. Δ = -Δ_
   @. Δ_ = η * mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ)
   @. Δ += 2Δ_
@@ -348,7 +348,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
 function apply!(o::ADAGrad, x, Δ)
   η = o.eta
   acc = get!(() -> fill!(similar(x), ϵ), o.acc, x)::typeof(x)
-  @. acc += Δ^2
+  @. acc += Δ * conj(Δ)
   @. Δ *= η / (√acc + ϵ)
 end

@@ -379,11 +379,11 @@ ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
 function apply!(o::ADADelta, x, Δ)
   ρ = o.rho
   acc, Δacc = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
-  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
   # DON'T remove epsilon from numerator
   # or even out of the square roots
   @. Δ *= √(Δacc + ϵ) / √(acc + ϵ)
-  @. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
+  @. Δacc = ρ * Δacc + (1 - ρ) * Δ * conj(Δ)
   return Δ
 end

@@ -463,7 +463,7 @@ function apply!(o::NADAM, x, Δ)
   β1p, β2p = βp

   @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
   @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η
   βp .= βp .* β

@@ -524,7 +524,7 @@ function apply!(o::AdaBelief, x, Δ)
   η, β = o.eta, o.beta
   mt, st = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
   @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. st = β[2] * st + (1 - β[2]) * (Δ - mt)^2
+  @. st = β[2] * st + (1 - β[2]) * (Δ - mt) * conj(Δ - mt)
   @. Δ = η * mt / (√(st) + ϵ)
   return Δ
 end
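
As a quick sanity check of the new behaviour, here is a sketch against Flux's internal per-array update, apply! (an internal API; the exact call path below is illustrative, not part of this diff):

using Flux

opt = ADAM(0.001)
x = randn(ComplexF64, 4)   # complex-valued "weights"
Δ = randn(ComplexF64, 4)   # complex-valued gradient for x

# apply! mutates Δ into the step to subtract from x. With this commit the
# second-moment accumulator vt = β₂·vt + (1−β₂)·Δ·conj(Δ) has zero imaginary
# part, so the effective step size η / (√vt + ϵ) is real, as intended.
Flux.Optimise.apply!(opt, x, Δ)
x .-= Δ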
