
Commit cbc1275

Merge #1776
1776: Use conjugates in optimizers to better learn on complex-valued inputs r=DhairyaLGandhi a=staticfloat

When weights are complex, the deltas to them will also be complex. In all optimizers that need a second-order estimate of gradient statistics, we generally want to use the `x * conj(x)` pattern rather than `x^2`. We can see the effect this has on ADAM with the following test:

```julia
begin
    # This model will learn `W = I` and `bias = 0`
    complex_init(dims...) = Flux.glorot_uniform(dims...) .+ 1im .* Flux.glorot_uniform(dims...)
    model = Chain(
        Dense(4, 4, tanh; init=complex_init),
        Dense(4, 16, tanh; init=complex_init),
        Dense(16, 4, tanh; init=complex_init),
        Dense(4, 4, tanh; init=complex_init),
    )

    # Loss function; note we don't need the `abs()` if we update `Flux.Losses.mse()` as below
    function loss(x)
        return abs.(Flux.Losses.mse(model(x), x))
    end

    # Keep track of loss from epoch to epoch
    losses = Float64[]
    dataset = [(randn(ComplexF32, 4, 10),)]
    params = Flux.params(model)
    opt = Flux.Optimise.ADAM(0.001)
    for epoch_idx in 1:10000
        Flux.train!(loss, params, dataset, opt)
        epoch_loss = loss(dataset[1][1])
        push!(losses, epoch_loss)
        if epoch_idx % 100 == 0
            @info("epoch done", epoch_idx, epoch_loss)
        end
    end

    # Plot the loss
    fig = Figure()
    meta_ax = Axis(fig[1,1])
    lines!(meta_ax, log.(losses); label="Training loss")
    fig[1,2] = Legend(fig, meta_ax, "Learning Stats")
    fig
end
```

The training loss before the fix looks like this:

![without_workaround](https://user-images.githubusercontent.com/130920/142955143-385c5ca9-b2d7-4129-aae0-152741661689.png)

Whereas after both of these commits, it looks like this:

![with_workaround](https://user-images.githubusercontent.com/130920/142955168-807943d7-a2d4-4f7a-82a6-fbab0610e407.png)

Note that while the absolute value of the loss is comparable in this simple example, the loss landscape is significantly more chaotic. With a higher learning rate, the "fixed" version is able to learn much faster:

![download-1](https://user-images.githubusercontent.com/130920/142955367-e945e6c2-7045-42f7-8a7f-9135ee40c5b4.png)

Whereas the unfixed version simply diverges:

![download-2](https://user-images.githubusercontent.com/130920/142955420-8f32bb3c-5add-4fcb-86a6-eff7fac6dfaf.png)

Co-authored-by: Elliot Saba <staticfloat@gmail.com>
2 parents bb88c55 + 8c3d852 commit cbc1275
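
For intuition on the `x * conj(x)` pattern described above (a worked check, not part of the commit): for a complex gradient, `Δ^2` is generally complex and cannot serve as a variance-like accumulator, while `Δ * conj(Δ)` equals `abs2(Δ)`, which is real and non-negative; for real gradients the two coincide, so real-valued training is unaffected.

```julia
Δ = 1.0 + 1.0im

Δ^2              # 0.0 + 2.0im  -- complex, no longer a magnitude
Δ * conj(Δ)      # 2.0 + 0.0im  -- equals abs2(Δ) = |Δ|²
abs2(Δ)          # 2.0

δ = 3.0
δ^2 == δ * conj(δ)   # true: the change is a no-op for real gradients
```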

File tree

4 files changed: +51 -10 lines changed


src/losses/functions.jl

Lines changed: 2 additions & 1 deletion
```diff
@@ -44,7 +44,8 @@ julia> Flux.mse(y_model, y_true)
 """
 function mse(ŷ, y; agg = mean)
   _check_sizes(ŷ, y)
-  agg((ŷ .- y) .^ 2)
+  error = ŷ .- y
+  real(agg(error .* conj(error)))
 end
 
 """
```

src/optimise/optimisers.jl

Lines changed: 9 additions & 9 deletions
```diff
@@ -141,7 +141,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
 function apply!(o::RMSProp, x, Δ)
   η, ρ = o.eta, o.rho
   acc = get!(() -> zero(x), o.acc, x)::typeof(x)
-  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
   @. Δ *= η / (√acc + ϵ)
 end
 
@@ -179,7 +179,7 @@ function apply!(o::ADAM, x, Δ)
   end :: Tuple{typeof(x),typeof(x),Vector{Float64}}
 
   @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
   @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
   βp .= βp .* β
 
@@ -221,7 +221,7 @@ function apply!(o::RADAM, x, Δ)
   end :: Tuple{typeof(x),typeof(x),Vector{Float64},Ref{Int}}
 
   @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
   ρ = ρ∞ - 2t[] * βp[2] / (1 - βp[2])
   if ρ > 4
     r = √((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ))
@@ -311,7 +311,7 @@ function apply!(o::OADAM, x, Δ)
   end :: Tuple{typeof(x),typeof(x),typeof(x),Vector{Float64}}
 
   @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
   @. Δ = -Δ_
   @. Δ_ = η * mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ)
   @. Δ += 2Δ_
@@ -348,7 +348,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
 function apply!(o::ADAGrad, x, Δ)
   η = o.eta
   acc = get!(() -> fill!(similar(x), ϵ), o.acc, x)::typeof(x)
-  @. acc += Δ^2
+  @. acc += Δ * conj(Δ)
   @. Δ *= η / (√acc + ϵ)
 end
 
@@ -379,11 +379,11 @@ ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
 function apply!(o::ADADelta, x, Δ)
   ρ = o.rho
   acc, Δacc = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
-  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
   # DON'T remove epsilon from numerator
   # or even out of the square roots
   @. Δ *= √(Δacc + ϵ) / √(acc + ϵ)
-  @. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
+  @. Δacc = ρ * Δacc + (1 - ρ) * Δ * conj(Δ)
   return Δ
 end
 
@@ -463,7 +463,7 @@ function apply!(o::NADAM, x, Δ)
   β1p, β2p = βp
 
   @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
+  @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
   @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η
   βp .= βp .* β
 
@@ -524,7 +524,7 @@ function apply!(o::AdaBelief, x, Δ)
   η, β = o.eta, o.beta
   mt, st = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
   @. mt = β[1] * mt + (1 - β[1]) * Δ
-  @. st = β[2] * st + (1 - β[2]) * (Δ - mt)^2
+  @. st = β[2] * st + (1 - β[2]) * (Δ - mt) * conj(Δ - mt)
   @. Δ = η * mt / (√(st) + ϵ)
   return Δ
 end
```
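
To see the shared pattern in one place, here is a minimal RMSProp-style step on a complex parameter (a sketch; `η`, `ρ`, `ϵ`, and the sample gradient are arbitrary values chosen for the demo, not code from this commit):

```julia
η, ρ, ϵ = 0.001, 0.9, 1e-8

w   = [0.5 + 0.5im]     # complex parameter
Δ   = [0.2 - 0.1im]     # complex gradient w.r.t. `w`
acc = zero(w)           # running estimate of the squared gradient magnitude

# With `Δ .* conj.(Δ)` the accumulator holds |Δ|² (imaginary part zero), so the
# square root and the division below scale the step by a non-negative real number.
# With `Δ.^2` it would pick up an imaginary part and corrupt the step size.
acc .= ρ .* acc .+ (1 - ρ) .* Δ .* conj.(Δ)
w  .-= η .* Δ ./ (sqrt.(acc) .+ ϵ)
```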

test/losses.jl

Lines changed: 3 additions & 0 deletions
```diff
@@ -39,6 +39,9 @@ y = [1, 1, 0, 0]
 
 @testset "mse" begin
   @test mse(ŷ, y) ≈ (.1^2 + .9^2)/2
+
+  # Test that mse() loss works on complex values:
+  @test mse(0 + 0im, 1 + 1im) == 2
 end
 
 @testset "mae" begin
```

test/optimise.jl

Lines changed: 37 additions & 0 deletions
```diff
@@ -190,3 +190,40 @@ end
   Flux.update!(opt, θ, gs)
   @test w ≈ wold .- 0.1
 end
+
+# Flux PR #1776
+# We need to test that optimisers like ADAM that maintain an internal momentum
+# estimate properly calculate the second-order statistics on the gradients as
+# they flow backward through the model.  Previously, we would calculate second-
+# order statistics via `Δ^2` rather than the complex-aware `Δ * conj(Δ)`, which
+# wreaks all sorts of havoc on our training loops.  This test ensures that
+# a simple optimization is monotonically decreasing (up to learning step effects).
+@testset "Momentum Optimisers and complex values" begin
+  # Test every optimizer that has momentum internally
+  for opt_ctor in [ADAM, RMSProp, RADAM, OADAM, ADAGrad, ADADelta, NADAM, AdaBelief]
+    # Our "model" is just a complex number
+    w = zeros(ComplexF32, 1)
+
+    # Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x`
+    function loss()
+      # Deterministic training data is the best training data
+      x = ones(1, 1) + 1im*ones(1, 1)
+
+      # Manually implement `mse()` to allow demonstration of brokenness
+      # on older Flux builds that don't have a fixed `mse()`
+      return sum(abs2.(w * x .- conj(x)))
+    end
+
+    params = Flux.Params([w])
+    opt = opt_ctor(1e-2)
+
+    # Train for 10 iterations, enforcing that loss is monotonically decreasing
+    last_loss = Inf
+    for idx in 1:10
+      grads = Flux.gradient(loss, params)
+      @test loss() < last_loss
+      last_loss = loss()
+      Flux.update!(opt, params, grads)
+    end
+  end
+end
```
