@@ -5,6 +5,8 @@ using Optimisers: @.., @lazy
5
5
6
6
Random. seed! (1 )
7
7
8
+ # Fake "models" for testing
9
+
8
10
struct Foo; x; y; end
9
11
Functors. @functor Foo
10
12
Optimisers. trainable (x:: Foo ) = (x. y, x. x)
@@ -16,6 +18,8 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
16
18
mutable struct MutTwo; x; y; end
17
19
Functors. @functor MutTwo
18
20
21
+ # Simple rules for testing
22
+
19
23
struct DummyHigherOrder <: AbstractRule end
20
24
Optimisers. init (:: DummyHigherOrder , x:: AbstractArray ) =
21
25
(ones (eltype (x), size (x)), zero (x))
227
231
@test_throws MethodError Optimisers. update (sm, m)
228
232
end
229
233
230
- @testset " 2nd order gradient" begin
231
- m = (α = ([1.0 ], sin), γ = Float32[4 ,3 ,2 ])
232
-
233
- # Special rule which requires this:
234
- s = Optimisers. setup (BiRule (), m)
235
- g = (α = ([0.1 ], ZeroTangent ()), γ = [1 ,10 ,100 ],)
236
- s1, m1 = Optimisers. update (s, m, g, g)
237
- @test m1. α[1 ] == [0.9 ]
238
- @test_throws Exception Optimisers. update (s, m, g, map (x-> 2 .* x, g))
239
-
240
- # Ordinary rule which doesn't need it:
241
- s2 = Optimisers. setup (Adam (), m)
242
- s3, m3 = Optimisers. update (s2, m, g)
243
- s4, m4 = Optimisers. update (s2, m, g, g)
244
- @test m3. γ == m4. γ
245
- end
246
-
247
234
@testset " broadcasting macros" begin
248
235
x = [1.0 , 2.0 ]; y = [3 ,4 ]; z = [5 ,6 ]
249
236
@test (@lazy x + y * z) isa Broadcast. Broadcasted
@@ -365,34 +352,53 @@ end
365
352
@test model2. a === model2. b # tie of MutTwo structs is restored
366
353
@test model2. a != = model2. c # but a new tie is not created
367
354
end
368
- end
355
+ end # tied weights
356
+
357
+ @testset " 2nd-order interface" begin
358
+ @testset " BiRule" begin
359
+ m = (α = ([1.0 ], sin), γ = Float32[4 ,3 ,2 ])
360
+
361
+ # Special rule which requires this:
362
+ s = Optimisers. setup (BiRule (), m)
363
+ g = (α = ([0.1 ], ZeroTangent ()), γ = [1 ,10 ,100 ],)
364
+ s1, m1 = Optimisers. update (s, m, g, g)
365
+ @test m1. α[1 ] == [0.9 ]
366
+ @test_throws Exception Optimisers. update (s, m, g, map (x-> 2 .* x, g))
367
+
368
+ # Ordinary rule which doesn't need it:
369
+ s2 = Optimisers. setup (Adam (), m)
370
+ s3, m3 = Optimisers. update (s2, m, g)
371
+ s4, m4 = Optimisers. update (s2, m, g, g)
372
+ @test m3. γ == m4. γ
373
+ end
369
374
370
- @testset " higher order interface" begin
371
- w, b = rand (3 , 4 ), rand (3 )
372
-
373
- o = DummyHigherOrder ()
374
- psin = (w, b)
375
- dxs = map (x -> rand (size (x)... ), psin)
376
- dx2s = map (x -> rand (size (x)... ), psin)
377
- stin = Optimisers. setup (o, psin)
378
- stout, psout = Optimisers. update (stin, psin, dxs, dx2s)
379
-
380
- # hardcoded rule behavior for dummy rule
381
- @test psout[1 ] == dummy_update_rule (stin[1 ]. state, psin[1 ], dxs[1 ], dx2s[1 ])
382
- @test psout[2 ] == dummy_update_rule (stin[2 ]. state, psin[2 ], dxs[2 ], dx2s[2 ])
383
- @test stout[1 ]. state[1 ] == stin[1 ]. state[1 ] .+ 1
384
- @test stout[2 ]. state[2 ] == stin[2 ]. state[2 ] .+ 1
385
-
386
- # error if only given one derivative
387
- @test_throws MethodError Optimisers. update (stin, psin, dxs)
388
-
389
- # first-order rules compose with second-order
390
- ochain = OptimiserChain (Descent (0.1 ), o)
391
- stin = Optimisers. setup (ochain, psin)
392
- stout, psout = Optimisers. update (stin, psin, dxs, dx2s)
393
- @test psout[1 ] == dummy_update_rule (stin[1 ]. state[2 ], psin[1 ], 0.1 * dxs[1 ], dx2s[1 ])
394
- @test psout[2 ] == dummy_update_rule (stin[2 ]. state[2 ], psin[2 ], 0.1 * dxs[2 ], dx2s[2 ])
395
- end
375
+ @testset " DummyHigherOrder" begin
376
+ w, b = rand (3 , 4 ), rand (3 )
377
+
378
+ o = DummyHigherOrder ()
379
+ psin = (w, b)
380
+ dxs = map (x -> rand (size (x)... ), psin)
381
+ dx2s = map (x -> rand (size (x)... ), psin)
382
+ stin = Optimisers. setup (o, psin)
383
+ stout, psout = Optimisers. update (stin, psin, dxs, dx2s)
384
+
385
+ # hardcoded rule behavior for dummy rule
386
+ @test psout[1 ] == dummy_update_rule (stin[1 ]. state, psin[1 ], dxs[1 ], dx2s[1 ])
387
+ @test psout[2 ] == dummy_update_rule (stin[2 ]. state, psin[2 ], dxs[2 ], dx2s[2 ])
388
+ @test stout[1 ]. state[1 ] == stin[1 ]. state[1 ] .+ 1
389
+ @test stout[2 ]. state[2 ] == stin[2 ]. state[2 ] .+ 1
390
+
391
+ # error if only given one derivative
392
+ @test_throws MethodError Optimisers. update (stin, psin, dxs)
393
+
394
+ # first-order rules compose with second-order
395
+ ochain = OptimiserChain (Descent (0.1 ), o)
396
+ stin = Optimisers. setup (ochain, psin)
397
+ stout, psout = Optimisers. update (stin, psin, dxs, dx2s)
398
+ @test psout[1 ] == dummy_update_rule (stin[1 ]. state[2 ], psin[1 ], 0.1 * dxs[1 ], dx2s[1 ])
399
+ @test psout[2 ] == dummy_update_rule (stin[2 ]. state[2 ], psin[2 ], 0.1 * dxs[2 ], dx2s[2 ])
400
+ end
401
+ end # 2nd-order
396
402
397
403
end
398
404
@testset verbose= true " Destructure" begin
0 commit comments