Commit cb927a8
tests, II
1 parent 01c7af1

2 files changed: +41 -8 lines

src/train/explicit_train.jl
7 additions & 4 deletions
@@ -54,6 +54,7 @@ function train!(loss::Function, model, data, opt::FluxState)
   s = opt.state
   s isa IdDict && error("""Can't mix explicit & implicit modes!
     Once `FluxState` is initialised by `train!` in one mode, it cannot be used in the other.""")
+  # TODO check whether this loop ought to be in another function, for performance / type-stability.
   for d in data
     l, (g, _...) = explicit_withgradient(loss, model, data_splat(d)...)
     s, model = Optimisers.update!(s, model, g)
@@ -110,11 +111,13 @@ function train!(loss::Function, model, opt::FluxState)
 end

 # This method lets you use Optimisers.Descent() instead of Flux.Descent(), when there is no state
-function train!(loss::Function, model, data, opt::Optimisers.AbstractRule)
+function train!(loss::Function, model, data, rule::Optimisers.AbstractRule)
+  opt = FluxState(rule, missing)
   _initialise!(opt, model)
-  fmap(opt.state, exclude = x -> x isa Optimsers.Leaf) do leaf
-    leaf.state isa Nothing || @warn "Optimiser state will be lost! Please wrap optimisation rule in `FluxState`, e.g. by using `Flux.Adam()`" leaf
+  @gensym warn_id
+  fmap(opt.state, exclude = x -> x isa Optimisers.Leaf) do leaf
+    leaf.state isa Nothing || @warn "Optimiser state will be discarded! Please wrap optimisation rule from Optimisers.jl in `FluxState`, e.g. by using `Flux.Adam()`" leaf maxlog=1 _id=warn_id
     leaf
   end
-  train!(loss, model, data, FluxState(opt))
+  train!(loss, model, data, opt)
 end
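
Note: a rough sketch of how the changed method is meant to be called, assuming the experimental `Flux.Train` API on this branch; the model, loss and data below are illustrative only and are not part of the commit.

    using Flux, Flux.Train
    import Optimisers

    w = randn(10, 10)                       # "true" weights the model should recover
    model = (weight = randn(10, 10), bias = zeros(10))
    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
    data = ((rand(10),) for _ in 1:1000)

    # A bare Optimisers.jl rule dispatches to the method changed above: it is wrapped
    # as FluxState(rule, missing), and if the rule carries per-parameter state the
    # @warn fires (once per call site, via maxlog=1 and the @gensym'd _id), since
    # that state is rebuilt and discarded on every call. Descent is stateless,
    # so nothing is lost here.
    train!(loss, model, data, Optimisers.Descent(0.1))

    # Preferred: construct the rule through Flux, so it arrives already wrapped in
    # a FluxState whose state persists across calls.
    train!(loss, model, data, Adam())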

test/train.jl
34 additions & 4 deletions
@@ -1,7 +1,7 @@
 using Flux.Train
 using Zygote: Params, gradient

-import FillArrays, ComponentArrays
+import Optimisers, FillArrays, ComponentArrays

 using Test
 using Random
@@ -29,10 +29,40 @@ end
 w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
 @testset for opt in [Descent(0.1), Adam()]
   @test opt isa FluxState
-  w′ = copy(w2)
-  b = zeros(10)
+  @test opt.state isa Missing
+
+  loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+  model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
+  @test loss(model, rand(10, 10)) > 1
+
+  train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
+  @test loss(model, rand(10, 10)) < 0.01
+  @test opt.state isa NamedTuple
+end
+
+# Test 3-arg `train!` method:
+@testset for opt in [Descent(0.1), Adam()]
+  @test opt isa FluxState
+  @test opt.state isa Missing
+
+  loss(m) = let x = rand(10)
+    Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+  end
+  model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
+  @test loss(model) > 1
+
+  for i in 1:10^5
+    train!(loss, model, opt)
+  end
+  @test loss(model) < 0.01
+  @test opt.state isa NamedTuple
+end
+
+# Test direct use of Optimisers.jl rule, only really OK for `Descent`:
+@testset for opt in [Optimisers.Descent(0.1), Optimisers.Adam()]
+  @test opt isa Optimisers.AbstractRule
   loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
-  model = (weight=w′, bias=b, ignore=nothing)
+  model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
   @test loss(model, rand(10, 10)) > 1
   train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
   @test loss(model, rand(10, 10)) < 0.01
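
Note: the comment "only really OK for `Descent`" reflects that Descent is the one rule here without per-parameter state, so discarding state between calls loses nothing, whereas Adam's moment buffers would be rebuilt from zero on every call. A sketch of how to see this with Optimisers.jl alone, using its per-array `Optimisers.init` (an assumption about its behaviour, not part of this commit):

    import Optimisers

    x = zeros(Float32, 3)
    Optimisers.init(Optimisers.Descent(0.1), x)  # expected: nothing, i.e. stateless, safe to discard
    Optimisers.init(Optimisers.Adam(), x)        # expected: moment buffers that the bare-rule path would discard each call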
