Skip to content

Commit bf54f76

Browse files
authored
Merge pull request #110 from darsnack/higher-order-tests
Add tests for higher order interface
2 parents 5f51632 + 6133591 commit bf54f76

File tree

3 files changed

+43
-1
lines changed

3 files changed

+43
-1
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,2 @@
1+
Manifest.toml
2+
.vscode/

src/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ function update(tree, x, x̄s...)
5050
end
5151

5252
# default all rules to first order calls
53-
apply!(o, state, x, dx, dxs...) = apply!(o, state, x, dx)
53+
# Fallback for calls that supply two or more gradient terms: everything past
# `dx` is dropped and the call is forwarded to the four-argument method.
# `dx2` is a required positional argument so that this method can never match
# a four-argument call itself.
function apply!(o, state, x, dx, dx2, dxs...)
    return apply!(o, state, x, dx)
end
5454

5555
"""
5656
isnumeric(x) -> Bool

test/runtests.jl

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,19 @@ struct TwoThirds a; b; c; end
1313
Functors.@functor TwoThirds (a, c)
1414
Optimisers.trainable(x::TwoThirds) = (a = x.a,)
1515

16+
# Fieldless rule type; methods dispatch on it purely by type.
struct DummyHigherOrder <: AbstractRule
end
17+
18+
# Initial state for `DummyHigherOrder` on an array `x`: a 2-tuple holding an
# array of ones (with `x`'s element type and size) followed by `zero(x)`.
function Optimisers.init(::DummyHigherOrder, x::AbstractArray)
    all_ones = ones(eltype(x), size(x))
    all_zeros = zero(x)
    return (all_ones, all_zeros)
end
20+
21+
# Reference formula: subtract from `p` the sum of `dx` scaled by `st[1]` and
# `dx2` scaled by `st[2]`, all elementwise via broadcasting.
function dummy_update_rule(st, p, dx, dx2)
    w1 = st[1]
    w2 = st[2]
    return p .- (w1 .* dx .+ w2 .* dx2)
end
22+
# Two-gradient rule for `DummyHigherOrder`: combines `dx` and `dx2` using the
# two components of `state` as elementwise weights, and returns a new state in
# which both components have been incremented by one, together with `dx`.
function Optimisers.apply!(::DummyHigherOrder, state, x, dx, dx2)
    w1, w2 = state
    @.. dx = w1 * dx + w2 * dx2

    next_state = (w1 .+ 1, w2 .+ 1)
    return next_state, dx
end
28+
1629
@testset verbose=true "Optimisers.jl" begin
1730
@testset verbose=true "Features" begin
1831

@@ -220,6 +233,33 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
220233
@test_throws ArgumentError Optimisers.setup(AdamW(), m2)
221234
end
222235

236+
@testset "higher order interface" begin
    # Two parameter arrays of different shapes, each paired with random
    # first- and second-order gradient terms of the same size.
    w, b = rand(3, 4), rand(3)

    o = DummyHigherOrder()
    psin = (w, b)
    dxs = map(x -> rand(size(x)...), psin)
    dx2s = map(x -> rand(size(x)...), psin)
    stin = Optimisers.setup(o, psin)
    stout, psout = Optimisers.update(stin, psin, dxs, dx2s)

    # hardcoded rule behavior for dummy rule
    # The parameter checks use `≈` instead of `==`: the expected value is
    # recomputed by a separate floating-point expression, so exact equality
    # would depend on both sides being evaluated in an identical order.
    @test psout[1] ≈ dummy_update_rule(stin[1].state, psin[1], dxs[1], dx2s[1])
    @test psout[2] ≈ dummy_update_rule(stin[2].state, psin[2], dxs[2], dx2s[2])
    # The state components are incremented by exactly one, so `==` is kept.
    @test stout[1].state[1] == stin[1].state[1] .+ 1
    @test stout[2].state[2] == stin[2].state[2] .+ 1

    # error if only given one derivative
    @test_throws MethodError Optimisers.update(stin, psin, dxs)

    # first-order rules compose with second-order
    ochain = OptimiserChain(Descent(0.1), o)
    stin = Optimisers.setup(ochain, psin)
    stout, psout = Optimisers.update(stin, psin, dxs, dx2s)
    # `state[2]` selects the second rule's state within the chain's state.
    @test psout[1] ≈ dummy_update_rule(stin[1].state[2], psin[1], 0.1 * dxs[1], dx2s[1])
    @test psout[2] ≈ dummy_update_rule(stin[2].state[2], psin[2], 0.1 * dxs[2], dx2s[2])
end
262+
223263
end
224264
@testset verbose=true "Destructure" begin
225265
include("destructure.jl")

0 commit comments

Comments
 (0)