  Flux.update!(opt, θ, gs)
  @test w ≈ wold .- 0.1
end

# Flux PR #1776
# We need to test that optimisers like ADAM that maintain an internal momentum
# estimate properly calculate the second-order statistics on the gradients as
# they flow backward through the model. Previously, we would calculate second-
# order statistics via `Δ^2` rather than the complex-aware `Δ * conj(Δ)`, which
# wreaks all sorts of havoc on our training loops. This test ensures that
# a simple optimization is monotonically decreasing (up to learning step effects).
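# (For example, a purely imaginary gradient Δ = 0 + 1im gives Δ^2 = -1, a negative
# "second moment", whereas Δ * conj(Δ) = abs2(Δ) = 1, the correct nonnegative magnitude.)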
@testset "Momentum Optimisers and complex values" begin
    # Test every optimizer that has momentum internally
    for opt_ctor in [ADAM, RMSProp, RADAM, OADAM, ADAGrad, ADADelta, NADAM, AdaBelief]
        # Our "model" is just a complex number
        w = zeros(ComplexF32, 1)

        # Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x`
        function loss()
            # Deterministic training data is the best training data
            x = ones(1, 1) + 1im*ones(1, 1)

            # Manually implement `mse()` to allow demonstration of brokenness
            # on older Flux builds that don't have a fixed `mse()`
            return sum(abs2.(w * x .- conj(x)))
        end

        params = Flux.Params([w])
        opt = opt_ctor(1e-2)

        # Train for 10 iterations, enforcing that loss is monotonically decreasing
        last_loss = Inf
        for idx in 1:10
            grads = Flux.gradient(loss, params)
            @test loss() < last_loss
            last_loss = loss()
            Flux.update!(opt, params, grads)
        end
    end
end
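
# Illustrative sketch only (not the Flux implementation): an ADAM-style
# second-moment update, assuming a hypothetical accumulator `vt` and decay `β2`.
# Using `abs2.(Δ)` (i.e. `Δ .* conj.(Δ)`) keeps the estimate real and
# nonnegative for complex gradients, whereas `Δ .^ 2` can turn negative or
# complex and corrupts the square root in the update rule.
second_moment(vt, Δ; β2 = 0.999) = β2 .* vt .+ (1 - β2) .* abs2.(Δ)
# e.g. second_moment(zeros(1), [0.0 + 1.0im]) ≈ [0.001]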