Skip to content

Commit 3bca907

Browse files
committed
group the tests
1 parent d13e52a commit 3bca907

File tree

1 file changed

+51
-44
lines changed

1 file changed

+51
-44
lines changed

test/runtests.jl

Lines changed: 51 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using Optimisers: @.., @lazy
55

66
Random.seed!(1)
77

8+
# Fake "models" for testing
9+
810
struct Foo; x; y; end
911
Functors.@functor Foo
1012
Optimisers.trainable(x::Foo) = (x.y, x.x)
@@ -16,6 +18,8 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
1618
mutable struct MutTwo; x; y; end
1719
Functors.@functor MutTwo
1820

21+
# Simple rules for testing
22+
1923
struct DummyHigherOrder <: AbstractRule end
2024
Optimisers.init(::DummyHigherOrder, x::AbstractArray) =
2125
(ones(eltype(x), size(x)), zero(x))
@@ -227,23 +231,6 @@ end
227231
@test_throws MethodError Optimisers.update(sm, m)
228232
end
229233

230-
@testset "2nd order gradient" begin
231-
m == ([1.0], sin), γ = Float32[4,3,2])
232-
233-
# Special rule which requires this:
234-
s = Optimisers.setup(BiRule(), m)
235-
g == ([0.1], ZeroTangent()), γ = [1,10,100],)
236-
s1, m1 = Optimisers.update(s, m, g, g)
237-
@test m1.α[1] == [0.9]
238-
@test_throws Exception Optimisers.update(s, m, g, map(x->2 .* x, g))
239-
240-
# Ordinary rule which doesn't need it:
241-
s2 = Optimisers.setup(Adam(), m)
242-
s3, m3 = Optimisers.update(s2, m, g)
243-
s4, m4 = Optimisers.update(s2, m, g, g)
244-
@test m3.γ == m4.γ
245-
end
246-
247234
@testset "broadcasting macros" begin
248235
x = [1.0, 2.0]; y = [3,4]; z = [5,6]
249236
@test (@lazy x + y * z) isa Broadcast.Broadcasted
@@ -365,34 +352,54 @@ end
365352
@test model2.a === model2.b # tie of MutTwo structs is restored
366353
@test model2.a !== model2.c # but a new tie is not created
367354
end
368-
end
355+
end # tied weights
356+
357+
@testset "2nd-order interface" begin
358+
@testset "BiRule" begin
359+
m == ([1.0], sin), γ = Float32[4,3,2])
360+
361+
# Special rule which requires this:
362+
s = Optimisers.setup(BiRule(), m)
363+
g == ([0.1], ZeroTangent()), γ = [1,10,100],)
364+
s1, m1 = Optimisers.update(s, m, g, g)
365+
@test m1.α[1] == [0.9]
366+
@test_throws Exception Optimisers.update(s, m, g, map(x->2 .* x, g))
367+
368+
# Ordinary rule which doesn't need it:
369+
s2 = Optimisers.setup(Adam(), m)
370+
s3, m3 = Optimisers.update(s2, m, g)
371+
s4, m4 = Optimisers.update(s2, m, g, g)
372+
@test m3.γ == m4.γ
373+
end
369374

370-
@testset "higher order interface" begin
371-
w, b = rand(3, 4), rand(3)
372-
373-
o = DummyHigherOrder()
374-
psin = (w, b)
375-
dxs = map(x -> rand(size(x)...), psin)
376-
dx2s = map(x -> rand(size(x)...), psin)
377-
stin = Optimisers.setup(o, psin)
378-
stout, psout = Optimisers.update(stin, psin, dxs, dx2s)
379-
380-
# hardcoded rule behavior for dummy rule
381-
@test psout[1] == dummy_update_rule(stin[1].state, psin[1], dxs[1], dx2s[1])
382-
@test psout[2] == dummy_update_rule(stin[2].state, psin[2], dxs[2], dx2s[2])
383-
@test stout[1].state[1] == stin[1].state[1] .+ 1
384-
@test stout[2].state[2] == stin[2].state[2] .+ 1
385-
386-
# error if only given one derivative
387-
@test_throws MethodError Optimisers.update(stin, psin, dxs)
388-
389-
# first-order rules compose with second-order
390-
ochain = OptimiserChain(Descent(0.1), o)
391-
stin = Optimisers.setup(ochain, psin)
392-
stout, psout = Optimisers.update(stin, psin, dxs, dx2s)
393-
@test psout[1] == dummy_update_rule(stin[1].state[2], psin[1], 0.1 * dxs[1], dx2s[1])
394-
@test psout[2] == dummy_update_rule(stin[2].state[2], psin[2], 0.1 * dxs[2], dx2s[2])
395-
end
375+
@testset "DummyHigherOrder" begin
376+
w, b = rand(3, 4), rand(3)
377+
378+
o = DummyHigherOrder()
379+
psin = (w, b)
380+
dxs = map(x -> rand(size(x)...), psin)
381+
dx2s = map(x -> rand(size(x)...), psin)
382+
stin = Optimisers.setup(o, psin)
383+
stout, psout = Optimisers.update(stin, psin, dxs, dx2s)
384+
385+
# hardcoded rule behavior for dummy rule
386+
@test psout[1] == dummy_update_rule(stin[1].state, psin[1], dxs[1], dx2s[1])
387+
@test psout[2] == dummy_update_rule(stin[2].state, psin[2], dxs[2], dx2s[2])
388+
@test stout[1].state[1] == stin[1].state[1] .+ 1
389+
@test stout[2].state[2] == stin[2].state[2] .+ 1
390+
391+
# error if only given one derivative
392+
@test_throws MethodError Optimisers.update(stin, psin, dxs)
393+
394+
# first-order rules compose with second-order
395+
ochain = OptimiserChain(Descent(0.1), o)
396+
stin = Optimisers.setup(ochain, psin)
397+
stout, psout = Optimisers.update(stin, psin, dxs, dx2s)
398+
@test psout[1] == dummy_update_rule(stin[1].state[2], psin[1], 0.1 * dxs[1], dx2s[1])
399+
@test psout[2] == dummy_update_rule(stin[2].state[2], psin[2], 0.1 * dxs[2], dx2s[2])
400+
end
401+
end # 2nd-order
402+
end
396403

397404
end
398405
@testset verbose=true "Destructure" begin

0 commit comments

Comments
 (0)