@@ -16,10 +16,19 @@ RULES = [
16
16
OptimiserChain (WeightDecay (), OADAM (), ClipGrad (1 )),
17
17
]
18
18
19
# Short unqualified type name for a rule, just for printing testset headings.
name(o) = nameof(typeof(o))
# A chain prints as the names of its component rules, joined by arrows.
name(o::OptimiserChain) = join(map(name, o.opts), " → ")
# Per-optimiser loss history, kept for debugging these testsets: makes it
# easy to plot each optimiser's loss curve afterwards. Keyed by `name(o)`.
# `const` so the global binding is typed-stable-ish and not rebindable;
# testsets reset it with `empty!(LOG)` rather than reassigning.
const LOG = Dict()

# Return a drop-in replacement for `gradient` that also records the loss:
# it computes value and gradient via `Zygote.withgradient`, pushes the loss
# onto `LOG[name(o)]` (creating a Float32 vector on first use), and returns
# just the gradient tuple, like `gradient` would.
loggradient(o) = (f, xs...) -> begin
    y, dxs = Zygote.withgradient(f, xs...)
    push!(get!(() -> Float32[], LOG, name(o)), y)
    dxs  # save the loss, return the gradient
end
22
30
@testset " independence" begin
31
+ empty! (LOG)
23
32
@testset " $(name (o)) " for o in RULES
24
33
w = randn (10 , 10 )
25
34
w′ = randn (10 , 10 )
@@ -28,22 +37,23 @@ name(o::OptimiserChain) = join(name.(o.opts), " → ")
28
37
st = Optimisers. setup (o, w)
29
38
for t = 1 : 10 ^ 5
30
39
x = rand (10 )
31
- gs = gradient (w -> iloss (x, w, w′), w)
40
+ gs = loggradient (o) (w -> iloss (x, w, w′), w)
32
41
st, w = Optimisers. update! (st, w, gs... )
33
42
end
34
43
@test iloss (rand (10 , 10 ), w, w′) < 0.01
35
44
end
36
45
end
37
46
38
47
@testset verbose= true " simple sum" begin
48
+ empty! (LOG)
39
49
@testset " $(name (o)) " for o in RULES
40
50
m = shuffle! (reshape (1 : 64 , 8 , 8 ) .+ 0.0 )
41
51
s = Optimisers. setup (o, m)
42
52
for _ in 1 : 10 ^ 5
43
- g = gradient (x -> sum (abs2, x + x' ), m)[1 ]
53
+ g = loggradient (o) (x -> sum (abs2, x + x' ), m)[1 ]
44
54
s, m = Optimisers. update! (s, m, g)
45
55
end
46
- # @test sum(m) < sum(1:64)
56
+ @test sum (m) < sum (1 : 64 )
47
57
if sum (m) < 1
48
58
@test sum (m) < 1
49
59
else
54
64
end
55
65
56
66
# Each rule should be able to fit the elementwise products α .* β to a
# random target, using the OUT-of-place `Optimisers.update` (other testsets
# exercise `update!`).
@testset "original" begin
    empty!(LOG)
    @testset "$(name(o))" for o in RULES
        target = (α = rand(3, 3), β = rand(3, 3))
        params = (α = 5rand(3, 3), β = rand(3, 3))
        state = Optimisers.setup(o, params)
        loss(x, y) = mean((x.α .* x.β .- y.α .* y.β) .^ 2)
        # Start far from the target, so the final test is meaningful:
        @test loss(params, target) > 1
        for _ in 1:10^4
            grads = loggradient(o)(x -> loss(x, target), params)
            state, params = Optimisers.update(state, params, grads...)
        end
        @test loss(params, target) < 0.001
    end
end
70
81
71
82
@testset verbose= true " StaticArrays" begin
83
+ empty! (LOG)
72
84
@testset " $(name (o)) " for o in RULES
73
85
W1 = @SMatrix randn (10 , 10 )
74
86
b1 = @SVector randn (10 )
82
94
@test s_loss (model, x, y) > 10
83
95
state = Optimisers. setup (o, model)
84
96
for t = 1 : 10 ^ 3
85
- g = gradient (m -> s_loss (m, x, y), model)[1 ]
97
+ g = loggradient (o) (m -> s_loss (m, x, y), model)[1 ]
86
98
state, model = Optimisers. update! (state, model, g)
87
99
end
88
100
if o isa Descent
94
106
end
95
107
end
96
108
97
- @testset verbose = true " element types" begin
109
+ @testset " element types" begin
98
110
@testset " $(name (o)) " for o in RULES
99
111
marray = (Float32[1 ,2 ], Float64[3 ,4 ], Float16[5 ,6 ])
100
112
types = map (eltype, marray)
166
178
end
167
179
end
168
180
181
# Regression test for complex-valued parameters, ported from FluxML/Flux.jl#1776.
# Fixes the typo "numebers" in the testset heading.
@testset "with complex numbers: Flux#1776" begin
    empty!(LOG)
    @testset "$(name(opt))" for opt in [
        # The Flux PR had 1e-2 for all. But ADADelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too:
        ADAM(1e-2), RMSProp(1e-2), RADAM(1e-2), OADAM(1e-2), ADAGrad(1e-2), ADADelta(0.9, 1e-5), NADAM(1e-2), AdaBelief(1e-2),
        # These weren't in Flux PR:
        Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), ADAMW(1e-2),
    ]
        # Our "model" is just a complex number
        model = (w = zeros(ComplexF64, 1),)

        # Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x`
        function loss(m)
            # Deterministic training data is the best training data
            x = ones(1, 1) + 1im * ones(1, 1)
            # Manually implement `mse()` to allow demonstration of brokenness
            # on older Flux builds that don't have a fixed `mse()`
            return sum(abs2.(m.w * x .- conj(x)))
        end
        # Initial loss with w = 0 is |0 - conj(1+i)|² = 2:
        @test loss(model) ≈ 2.0

        state = Optimisers.setup(opt, model)

        # Train for 10 iterations, enforcing that loss is monotonically decreasing
        last_loss = Inf
        for idx in 1:10
            grads = loggradient(opt)(loss, model)
            state, model = Optimisers.update!(state, model, grads...)
            opt isa Union{Momentum, Nesterov} && idx > 8 && continue  # these are very flat at the end
            @test loss(model) < last_loss
            last_loss = loss(model)
        end
        @test loss(model) < 1.9

        # Repeat with StaticArrays (plain `gradient` here, so the LOG curve
        # above isn't mixed with this second run):
        static_model = (w = SA[0.0 + 0im],)
        static_state = Optimisers.setup(opt, static_model)
        function static_loss(m)
            x = hcat(SA[1.0 + im])
            sum(abs2.(m.w * x .- conj(x)))
        end
        last_loss = Inf
        for idx in 1:10
            grads = gradient(static_loss, static_model)
            static_state, static_model = Optimisers.update!(static_state, static_model, grads...)
            opt isa Union{Momentum, Nesterov} && idx > 8 && continue
            @test static_loss(static_model) < last_loss
            last_loss = static_loss(static_model)
        end
        @test static_loss(static_model) < 1.9
    end
end
0 commit comments