
Commit 8c3d852

Add tests for complex valued training
1 parent 9326702 commit 8c3d852

File tree

  test/losses.jl
  test/optimise.jl

2 files changed: 40 additions & 0 deletions

test/losses.jl

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,9 @@ y = [1, 1, 0, 0]
 
 @testset "mse" begin
   @test mse(ŷ, y) ≈ (.1^2 + .9^2)/2
+
+  # Test that mse() loss works on complex values:
+  @test mse(0 + 0im, 1 + 1im) == 2
 end
 
 @testset "mae" begin
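The added assertion only holds if the squared error is taken with `abs2` (i.e. `Δ * conj(Δ)`, which is real) rather than `Δ^2`: |0 − (1 + 1im)|² = 1 + 1 = 2, whereas (0 − (1 + 1im))² = 2im. A minimal sketch of such a complex-aware loss, for illustration only (not necessarily Flux's exact definition of `mse`; `mse_sketch` is a hypothetical name):

    using Statistics: mean

    # Complex-aware mean squared error: squaring via abs2 keeps the result real.
    mse_sketch(ŷ, y) = mean(abs2.(ŷ .- y))

    mse_sketch(0 + 0im, 1 + 1im)  # == 2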

test/optimise.jl

Lines changed: 37 additions & 0 deletions
@@ -190,3 +190,40 @@ end
   Flux.update!(opt, θ, gs)
   @test w ≈ wold .- 0.1
 end
+
+# Flux PR #1776
+# We need to test that optimisers like ADAM that maintain an internal momentum
+# estimate properly calculate the second-order statistics on the gradients as
+# they flow backward through the model. Previously, we would calculate second-
+# order statistics via `Δ^2` rather than the complex-aware `Δ * conj(Δ)`, which
+# wreaks all sorts of havoc on our training loops. This test ensures that
+# a simple optimization is monotonically decreasing (up to learning step effects)
+@testset "Momentum Optimisers and complex values" begin
+  # Test every optimizer that has momentum internally
+  for opt_ctor in [ADAM, RMSProp, RADAM, OADAM, ADAGrad, ADADelta, NADAM, AdaBelief]
+    # Our "model" is just a complex number
+    w = zeros(ComplexF32, 1)
+
+    # Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x`
+    function loss()
+      # Deterministic training data is the best training data
+      x = ones(1, 1) + 1im*ones(1, 1)
+
+      # Manually implement `mse()` to allow demonstration of brokenness
+      # on older Flux builds that don't have a fixed `mse()`
+      return sum(abs2.(w * x .- conj(x)))
+    end
+
+    params = Flux.Params([w])
+    opt = opt_ctor(1e-2)
+
+    # Train for 10 iterations, enforcing that loss is monotonically decreasing
+    last_loss = Inf
+    for idx in 1:10
+      grads = Flux.gradient(loss, params)
+      @test loss() < last_loss
+      last_loss = loss()
+      Flux.update!(opt, params, grads)
+    end
+  end
+end
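The fix the comment block above refers to lives in the optimisers' second-moment accumulators. Below is an illustrative, simplified ADAM-style step (not Flux's actual source; the function name is hypothetical) showing the relevant change: accumulating `abs2(Δ)` (== `Δ * conj(Δ)`, real and non-negative) instead of `Δ^2`, which is complex for complex gradients and corrupts the step size.

    # Simplified ADAM-style step, for illustration only (no bias correction).
    function adam_like_step!(w, Δ, mt, vt; η=1e-2, β1=0.9, β2=0.999, ϵ=1e-8)
        @. mt = β1 * mt + (1 - β1) * Δ
        # Second-moment estimate: abs2(Δ) == Δ * conj(Δ) stays real,
        # whereas Δ^2 would be complex for complex-valued gradients.
        @. vt = β2 * vt + (1 - β2) * abs2(Δ)
        @. w -= η * mt / (sqrt(vt) + ϵ)
        return w
    end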
