@@ -13,6 +13,19 @@ struct TwoThirds a; b; c; end
13
13
# Test fixture wiring for `TwoThirds`: only fields `a` and `c` are passed to `@functor`,
# and only `a` is reported as trainable.
Functors.@functor TwoThirds (a, c)

function Optimisers.trainable(x::TwoThirds)
  return (a = x.a,)
end
15
15
16
# Field-less rule type used by the tests below; its `init`/`apply!` methods take a
# second derivative argument (`dx2`) in addition to the usual gradient.
struct DummyHigherOrder <: AbstractRule
end
17
+
18
# Initial state for `DummyHigherOrder`: a 2-tuple of arrays shaped like `x`,
# the first filled with ones and the second with zeros.
function Optimisers.init(::DummyHigherOrder, x::AbstractArray)
  T = eltype(x)
  return (ones(T, size(x)), zero(x))
end
20
+
21
# Reference result for one `DummyHigherOrder` step, computed elementwise:
#   p - (st[1] * dx + st[2] * dx2)
# `st` is indexed at 1 and 2; `p`, `dx` and `dx2` are broadcast against each other.
function dummy_update_rule(st, p, dx, dx2)
  weight1, weight2 = st[1], st[2]
  return p .- (weight1 .* dx .+ weight2 .* dx2)
end
22
# One step of the dummy two-gradient rule.
#
# `state` is destructured into two weights (see the `init` method for this rule):
# the first scales `dx`, the second scales `dx2`. `x` is not used.
# Returns `(new_state, dx)`, where `new_state` is both weights incremented by one.
function Optimisers.apply!(::DummyHigherOrder, state, x, dx, dx2)
  (weight1, weight2) = state
  @.. dx = weight1 * dx + weight2 * dx2

  bumped = (weight1 .+ 1, weight2 .+ 1)
  return bumped, dx
end
28
+
16
29
@testset verbose= true " Optimisers.jl" begin
17
30
@testset verbose= true " Features" begin
18
31
@@ -220,6 +233,33 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
220
233
@test_throws ArgumentError Optimisers. setup (AdamW (), m2)
221
234
end
222
235
236
@testset "higher order interface" begin
  rule = DummyHigherOrder()
  params = (rand(3, 4), rand(3))

  # one random "first derivative" and one random "second derivative" per parameter
  randlike = p -> rand(size(p)...)
  grads = map(randlike, params)
  grads2 = map(randlike, params)

  st0 = Optimisers.setup(rule, params)
  st1, params1 = Optimisers.update(st0, params, grads, grads2)

  # the updated parameters must match the hardcoded reference for the dummy rule
  for i in eachindex(params)
    @test params1[i] == dummy_update_rule(st0[i].state, params[i], grads[i], grads2[i])
  end
  # the rule increments its state arrays by one on each step
  @test st1[1].state[1] == st0[1].state[1] .+ 1
  @test st1[2].state[2] == st0[2].state[2] .+ 1

  # supplying only one derivative to a two-derivative rule is an error
  @test_throws MethodError Optimisers.update(st0, params, grads)

  # a first-order rule chained in front of the second-order one
  eta = 0.1
  chain = OptimiserChain(Descent(eta), rule)
  st_chain = Optimisers.setup(chain, params)
  _, params_chain = Optimisers.update(st_chain, params, grads, grads2)
  for i in eachindex(params)
    # state[2] is the chain's second rule, i.e. the dummy rule's state
    @test params_chain[i] ==
      dummy_update_rule(st_chain[i].state[2], params[i], eta * grads[i], grads2[i])
  end
end
262
+
223
263
end
224
264
@testset verbose= true " Destructure" begin
225
265
include (" destructure.jl" )
0 commit comments