
Commit adc0e85

Complex numbers alla Flux 1776 (#47)
* change broadcasting macro & remove bugs
* fix OADAM
* log the loss during tests
* complex numbers alla Flux 1776
* fixed?
* fix Momentum, Nesterov
* found the bug, fixed
* rm plotting code
* uncomment one test
* comments
1 parent a66f9a5 commit adc0e85

File tree

2 files changed: +81 -17 lines


src/rules.jl

Lines changed: 10 additions & 10 deletions
@@ -103,7 +103,7 @@ init(o::RMSProp, x::AbstractArray) = zero(x)
 function apply!(o::RMSProp, state, x, dx)
   η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state

-  @.. acc = ρ * acc + (1 - ρ) * dx^2
+  @.. acc = ρ * acc + (1 - ρ) * abs2(dx)
   dx′ = @lazy dx * (η / (sqrt(acc) + ϵ))

   return acc, dx′
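
The substantive change in each hunk of this file is the same: the second-moment accumulator switches from `dx^2` to `abs2(dx)`, which matches `dx^2` for real gradients but stays real and nonnegative when the gradient is complex. A minimal illustration in plain Julia (not part of the diff):

  dx = 3.0 + 4.0im
  dx^2      # -7.0 + 24.0im : complex, so sqrt(acc) would no longer measure a magnitude
  abs2(dx)  # 25.0          : |dx|^2, real and >= 0, and equal to dx^2 whenever dx is real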
@@ -136,7 +136,7 @@ function apply!(o::ADAM, state, x, dx)
   mt, vt, βt = state

   @.. mt = β[1] * mt + (1 - β[1]) * dx
-  @.. vt = β[2] * vt + (1 - β[2]) * dx ^ 2
+  @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
   dx′ = @lazy mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η

   return (mt, vt, βt .* β), dx′
@@ -171,7 +171,7 @@ function apply!(o::RADAM, state, x, dx)
   mt, vt, βt, t = state

   @.. mt = β[1] * mt + (1 - β[1]) * dx
-  @.. vt = β[2] * vt + (1 - β[2]) * dx^2
+  @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
   ρ = ρ∞ - 2*t * βt[2] / (1 - βt[2])
   if ρ > 4
     r = sqrt((ρ - 4) * (ρ - 2) * ρ∞/((ρ∞ - 4) * (ρ∞ - 2) * ρ))
@@ -244,7 +244,7 @@ function apply!(o::OADAM, state, x, dx)
   mt, vt, βt, term = state

   @.. mt = β[1] * mt + (1 - β[1]) * dx
-  @.. vt = β[2] * vt + (1 - β[2]) * dx^2
+  @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
   prev = copy(term)
   @.. term = η * mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
   dx′ = @lazy 2 * term - prev
@@ -277,7 +277,7 @@ function apply!(o::ADAGrad, state, x, dx)
   η, ϵ = o.eta, o.epsilon
   acc = state

-  @.. acc = acc + dx^2
+  @.. acc = acc + abs2(dx)
   dx′ = @lazy dx * η / (sqrt(acc) + ϵ)

   return acc, dx′
@@ -307,10 +307,10 @@ function apply!(o::ADADelta, state, x, dx)
   ρ, ϵ = o.rho, o.epsilon
   acc, Δacc = state

-  @.. acc = ρ * acc + (1 - ρ) * dx^2
+  @.. acc = ρ * acc + (1 - ρ) * abs2(dx)
   # DON'T remove epsilon from numerator or even out of the square roots!
   dx′ = @. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)  # Cannot be lazy as this needs the old Δacc
-  @.. Δacc = ρ * Δacc + (1 - ρ) * dx′^2
+  @.. Δacc = ρ * Δacc + (1 - ρ) * abs2(dx′)

   return (acc, Δacc), dx′
 end
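
A scalar sketch of the warning in the ADADelta hunk above, using stand-in numbers for the state arrays (illustration only, not part of the commit): after the first accumulator update, Δacc is still zero, so the ϵ inside the numerator's square root is the only thing that keeps dx′, and hence Δacc, from being stuck at zero.

  ρ, ϵ, dx = 0.9, 1e-8, 0.5
  acc  = (1 - ρ) * abs2(dx)              # acc after its first update
  Δacc = 0.0                             # Δacc still at its initial value
  dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)    # tiny but nonzero, so Δacc can start to grow
  dx * sqrt(Δacc) / sqrt(acc + ϵ)        # exactly 0.0; Δacc would then stay zero forever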
@@ -344,7 +344,7 @@ function apply!(o::AMSGrad, state, x, dx)
   mt, vt, v̂t = state

   @.. mt = β[1] * mt + (1 - β[1]) * dx
-  @.. vt = β[2] * vt + (1 - β[2]) * dx ^ 2
+  @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
   @.. v̂t = max(v̂t, vt)
   dx′ = @lazy η * mt / (sqrt(v̂t) + ϵ)

@@ -380,7 +380,7 @@ function apply!(o::NADAM, state, x, dx)
   mt, vt, βt = state

   @.. mt = β[1] * mt + (1 - β[1]) * dx
-  @.. vt = β[2] * vt + (1 - β[2]) * dx^2
+  @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
   dx′ = @lazy (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
               (sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η

@@ -433,7 +433,7 @@ function apply!(o::AdaBelief, state, x, dx)
   mt, st = state

   @.. mt = β[1] * mt + (1 - β[1]) * dx
-  @.. st = β[2] * st + (1 - β[2]) * (dx - mt)^2
+  @.. st = β[2] * st + (1 - β[2]) * abs2(dx - mt)
   dx′ = @lazy η * mt / (sqrt(st) + ϵ)

   return (mt, st), dx′
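
For context (not part of the commit): Zygote's gradient of a real-valued loss with respect to a complex parameter is itself complex, so the `dx` reaching `apply!` really can carry an imaginary part. A rough sketch, assuming Zygote is loaded:

  using Zygote

  f(w) = sum(abs2.(w .* (1.0 + im) .- (1.0 - im)))  # real-valued loss of a complex parameter
  g = Zygote.gradient(f, zeros(ComplexF64, 1))[1]
  isreal(g)   # false: the gradient has an imaginary part
  abs2.(g)    # real and >= 0, which is what the rules above now accumulate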

test/rules.jl

Lines changed: 71 additions & 7 deletions
@@ -16,10 +16,19 @@ RULES = [
   OptimiserChain(WeightDecay(), OADAM(), ClipGrad(1)),
 ]

-name(o) = typeof(o).name.name
+name(o) = typeof(o).name.name  # just for printing testset headings
 name(o::OptimiserChain) = join(name.(o.opts), " → ")

+LOG = Dict()  # for debugging these testsets, this makes it easy to plot each optimiser's loss
+
+loggradient(o) = (f, xs...) -> begin
+  y, dxs = Zygote.withgradient(f, xs...)
+  push!(get!(() -> Float32[], LOG, name(o)), y)
+  dxs  # save the loss, return the gradient
+end
+
 @testset "independence" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     w = randn(10, 10)
     w′ = randn(10, 10)
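
The `LOG` dictionary collects one `Float32` loss trace per rule. A sketch of how it could be inspected after a test run (illustration only; the plotting code mentioned in the commit message was removed from the PR):

  for (rule, losses) in LOG
    println(rule, ": ", length(losses), " steps, loss ", first(losses), " -> ", last(losses))
  end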
@@ -28,22 +37,23 @@ name(o::OptimiserChain) = join(name.(o.opts), " → ")
     st = Optimisers.setup(o, w)
     for t = 1:10^5
       x = rand(10)
-      gs = gradient(w -> iloss(x, w, w′), w)
+      gs = loggradient(o)(w -> iloss(x, w, w′), w)
       st, w = Optimisers.update!(st, w, gs...)
     end
     @test iloss(rand(10, 10), w, w′) < 0.01
   end
 end

 @testset verbose=true "simple sum" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     m = shuffle!(reshape(1:64, 8, 8) .+ 0.0)
     s = Optimisers.setup(o, m)
     for _ in 1:10^5
-      g = gradient(x -> sum(abs2, x + x'), m)[1]
+      g = loggradient(o)(x -> sum(abs2, x + x'), m)[1]
       s, m = Optimisers.update!(s, m, g)
     end
-    # @test sum(m) < sum(1:64)
+    @test sum(m) < sum(1:64)
     if sum(m) < 1
       @test sum(m) < 1
     else
@@ -54,21 +64,23 @@ end
 end

 @testset "original" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     w′ = (α = rand(3, 3), β = rand(3, 3))
     w = (α = 5rand(3, 3), β = rand(3, 3))
     st = Optimisers.setup(o, w)
     loss(x, y) = mean((x.α .* x.β .- y.α .* y.β) .^ 2)
     @test loss(w, w′) > 1
     for i = 1:10^4
-      gs = gradient(x -> loss(x, w′), w)
+      gs = loggradient(o)(x -> loss(x, w′), w)
       st, w = Optimisers.update(st, w, gs...)
     end
     @test loss(w, w′) < 0.001
   end
 end

 @testset verbose=true "StaticArrays" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     W1 = @SMatrix randn(10, 10)
     b1 = @SVector randn(10)
@@ -82,7 +94,7 @@ end
     @test s_loss(model, x, y) > 10
     state = Optimisers.setup(o, model)
     for t = 1:10^3
-      g = gradient(m -> s_loss(m, x, y), model)[1]
+      g = loggradient(o)(m -> s_loss(m, x, y), model)[1]
       state, model = Optimisers.update!(state, model, g)
     end
     if o isa Descent
@@ -94,7 +106,7 @@ end
   end
 end

-@testset verbose=true "element types" begin
+@testset "element types" begin
   @testset "$(name(o))" for o in RULES
     marray = (Float32[1,2], Float64[3,4], Float16[5,6])
     types = map(eltype, marray)
@@ -166,3 +178,55 @@ end
   end
 end

+@testset "with complex numbers: Flux#1776" begin
+  empty!(LOG)
+  @testset "$(name(opt))" for opt in [
+    # The Flux PR had 1e-2 for all. But ADADelta(ρ) needs ρ ≈ 0.9, not small. And it helps to make ϵ not too small too:
+    ADAM(1e-2), RMSProp(1e-2), RADAM(1e-2), OADAM(1e-2), ADAGrad(1e-2), ADADelta(0.9, 1e-5), NADAM(1e-2), AdaBelief(1e-2),
+    # These weren't in the Flux PR:
+    Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), ADAMW(1e-2),
+    ]
+    # Our "model" is just a complex number
+    model = (w = zeros(ComplexF64, 1),)
+
+    # Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x`
+    function loss(m)
+      # Deterministic training data is the best training data
+      x = ones(1, 1) + 1im*ones(1, 1)
+      # Manually implement `mse()` to allow demonstration of brokenness
+      # on older Flux builds that don't have a fixed `mse()`
+      return sum(abs2.(m.w * x .- conj(x)))
+    end
+    @test loss(model) ≈ 2.0
+
+    state = Optimisers.setup(opt, model)
+
+    # Train for 10 iterations, enforcing that loss is monotonically decreasing
+    last_loss = Inf
+    for idx in 1:10
+      grads = loggradient(opt)(loss, model)
+      state, model = Optimisers.update!(state, model, grads...)
+      opt isa Union{Momentum, Nesterov} && idx > 8 && continue  # these are very flat at the end
+      @test loss(model) < last_loss
+      last_loss = loss(model)
+    end
+    @test loss(model) < 1.9
+
+    # Repeat with StaticArrays
+    static_model = (w = SA[0.0 + 0im],)
+    static_state = Optimisers.setup(opt, static_model)
+    function static_loss(m)
+      x = hcat(SA[1.0 + im])
+      sum(abs2.(m.w * x .- conj(x)))
+    end
+    last_loss = Inf
+    for idx in 1:10
+      grads = gradient(static_loss, static_model)
+      static_state, static_model = Optimisers.update!(static_state, static_model, grads...)
+      opt isa Union{Momentum, Nesterov} && idx > 8 && continue
+      @test static_loss(static_model) < last_loss
+      last_loss = static_loss(static_model)
+    end
+    @test static_loss(static_model) < 1.9
+  end
+end
