1776: Use conjugates in optimizers to better learn on complex-valued inputs r=DhairyaLGandhi a=staticfloat
When weights are complex, the deltas to them will also be complex. In
all optimizers that need a second-order estimate of gradient statistics,
we generally want to use the `x * conj(x)` pattern, rather than `x^2`.
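Concretely, the pattern looks something like this in an ADAM-style second-moment update (a standalone sketch for illustration, not Flux's actual optimizer code; the function name `adam_step!` and the state layout are assumptions):

```julia
# Minimal ADAM-style step showing the conjugate pattern on complex parameters.
function adam_step!(x, Δ, m, v; η = 1e-3, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)
    @. m = β1 * m + (1 - β1) * Δ
    # Δ * conj(Δ) == abs2(Δ): real and non-negative even for complex gradients,
    # whereas Δ^2 can be negative or complex and corrupt the effective step size.
    @. v = β2 * v + (1 - β2) * Δ * conj(Δ)
    @. x -= η * m / (sqrt(v) + ϵ)
    return x
end

# Usage on a complex parameter vector
x = randn(ComplexF64, 4)
Δ = randn(ComplexF64, 4)
m = zero(x)
v = zero(x)
adam_step!(x, Δ, m, v)
```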
We can see the effect this has on ADAM with the following test:
```julia
using Flux
using CairoMakie   # any Makie backend works; provides Figure, Axis, lines!, Legend

begin
    # This model will learn `W = I` and `bias = 0`
    complex_init(dims...) = Flux.glorot_uniform(dims...) .+ 1im .* Flux.glorot_uniform(dims...)
    model = Chain(
        Dense(4, 4, tanh; init=complex_init),
        Dense(4, 16, tanh; init=complex_init),
        Dense(16, 4, tanh; init=complex_init),
        Dense(4, 4, tanh; init=complex_init),
    )

    # Loss function; note we don't need the `abs()` if we update `Flux.Losses.mse()` as below
    function loss(x)
        return abs(Flux.Losses.mse(model(x), x))
    end

    # Keep track of loss from epoch to epoch
    losses = Float64[]
    dataset = [(randn(ComplexF32, 4, 10),)]
    params = Flux.params(model)
    opt = Flux.Optimise.ADAM(0.001)
    for epoch_idx in 1:10000
        Flux.train!(loss, params, dataset, opt)
        epoch_loss = loss(dataset[1][1])
        push!(losses, epoch_loss)
        if epoch_idx % 100 == 0
            @info("epoch done", epoch_idx, epoch_loss)
        end
    end

    # Plot the loss
    fig = Figure()
    meta_ax = Axis(fig[1, 1])
    lines!(meta_ax, log.(losses); label="Training loss")
    fig[1, 2] = Legend(fig, meta_ax, "Learning Stats")
    fig
end
```
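The loss-function comment above refers to making `mse` itself complex-safe; one way to do that (a sketch, the actual change in Flux may differ, and `complex_safe_mse` is a hypothetical name) is to accumulate `abs2` of the residuals so the result is always real:

```julia
using Statistics

# abs2(z) == z * conj(z), so this is real and non-negative for complex residuals.
complex_safe_mse(ŷ, y; agg = mean) = agg(abs2.(ŷ .- y))
```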
The training loss before the fix looks like this:

Whereas after both of these commits, it looks like this:

Note that while the loss values reached are comparable in this simple example, the unfixed training curve is significantly more chaotic. With a higher learning rate, the "fixed" version is able to learn much faster:

Whereas the unfixed version simply diverges:

Co-authored-by: Elliot Saba <staticfloat@gmail.com>