
Commit 989286f

tidy, tests
1 parent a68470c commit 989286f

8 files changed, +127 −381 lines changed


src/deprecations.jl

Lines changed: 0 additions & 5 deletions
@@ -34,11 +34,6 @@ struct Zeros
 end
 Zeros(args...) = Zeros() # was used both Dense(10, 2, initb = Zeros) and Dense(rand(2,10), Zeros())
 
-# function Optimise.update!(x::AbstractArray, x̄)
-#   Base.depwarn("`Flux.Optimise.update!(x, x̄)` was not used internally and has been removed. Please write `x .-= x̄` instead.", :update!)
-#   x .-= x̄
-# end
-
 function Diagonal(size::Integer...; kw...)
   Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal)
   Scale(size...; kw...)
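For context, a minimal sketch of what the surviving shim does at the REPL; the warning text comes from the diff, and the rest assumes a Flux version that already defines `Scale`:

using Flux

layer = Flux.Diagonal(3)   # emits the depwarn above, then forwards to Scale(3)
layer isa Flux.Scale       # true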

src/train/Train.jl

Lines changed: 6 additions & 116 deletions
@@ -4,7 +4,7 @@ using LinearAlgebra
 using Optimisers: Optimisers
 using Functors: fmap
 
-export train!, update!, adjust!, FluxState, @epochs,
+export train!, update!, adjust!, FluxState,
   Descent, Adam, Momentum, Nesterov, RMSProp,
   AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief #,
   # InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
@@ -15,7 +15,7 @@ export train!, update!, adjust!, FluxState, @epochs,
 
 """
     FluxState(rule, state=missing)
-
+
 This is an interface between the all-mutable world Flux.jl likes,
 and the could-be-immutable world that Optimisers.jl inhabits.
 
@@ -56,34 +56,14 @@ end
 
 ### Two styles of gradient, and their `train!` functions
 
-using ProgressLogging: @progress, @withprogress, @logprogress
+using ProgressLogging: @progress, @withprogress, @logprogress # TODO add progress logging again
 using Zygote: Zygote, Params
 
-include("explicit_train.jl.jl") # new!
-include("implicit_train.jl.jl") # Params etc, Zygote only
+include("explicit_train.jl") # new!
+include("implicit_train.jl") # Params etc, Zygote only
 
 explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor
 
-# using Requires # Flux doesn't use this right now
-# @init @require Diffractor="9f5e2b26-1114-432f-b630-d3fe2085c51c" begin
-#   @eval function explicit_withgradient(f, args...)
-#     y, back = Diffractor.∂⃖¹(f, args...)
-#     _, grads... = back(Zygote.sensitivity(y))
-#     return (; value = y, gradient = grads)
-#   end
-# end
-
-#=
-
-using Diffractor
-function Flux.Train.explicit_withgradient(f, args...)
-  y, back = Diffractor.∂⃖¹(f, args...)
-  _, grads... = back(one(y))
-  return (; value = y, gradient = grads)
-end
-
-=#
-
 ### Misc. related utilities
 
 """
@@ -107,94 +87,4 @@ function adjust!(opt::FluxState, eta::Real)
   return opt
 end
 
-"""
-    @epochs N body
-
-Run `body` expression `N` times. Mainly useful for quickly doing
-multiple epochs of training in a REPL.
-
-Functionally equivalent to this loop:
-```
-for _ in 1:N
-    body
-end
-```
-... but adds progress logging and `@info` messages,
-and returns the result of the last iteration.
-
-# Examples
-```jldoctest
-julia> Flux.@epochs 2 println("hello")
-[ Info: Epoch 1
-hello
-[ Info: Epoch 2
-hello
-```
-"""
-macro epochs(n, ex)
-  @gensym val
-  body = :(for i in 1:$(esc(n))
-      @info "Epoch $i"
-      $(esc(val)) = $(esc(ex))
-    end)
-  loop = Expr(:macrocall, Symbol("@progress"), __source__, body)
-  Expr(:block, :($(esc(val)) = nothing), loop, :($(esc(val))))
-  # TODO make this actualy return the value? Names aren't right.
-  #
-  # $loop
-  # # @progress for i in 1:$(esc(n))
-  # #   @info "Epoch $i"
-  # #   $(esc(val)) = $(esc(ex))
-  # # end
-  # $val # DOESN"T WORK! Expr(:macrocall, ...) ?
-  # end
-end
-
-end
-
-
-#=
-
-using Flux, Random
-data = [(rand(3,2).*[i,1,20/i], [i i]) for i in 1:50] |> shuffle!;
-
-# This exact code works on Flux@0.13. There, train! returns nothing:
-model2 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
-opt2 = Flux.Adam()
-Flux.train!(Flux.params(model2), data, opt2) do x, y
-  Flux.mse(model2(x), y)
-end
-opt2 # contains an IdDict
-
-# This is the new "explicit" method of Train
-model1 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
-opt1 = Flux.Adam()
-Flux.train!(model1, data, opt1) do m, x, y
-  Flux.mse(m(x), y)
-end |> sum
-opt1 # contains state tree
-
-# This is new 3-arg train!, one step not an iteration over data:
-x1, y1 = data[1]
-Flux.train!(model1, opt1) do m
-  Flux.mse(m(x1), y1)
-end
-
-
-
-
-
-julia> using ProgressLogging
-julia> @macroexpand1 @loop N body
-begin
-  x = nothing
-  @progress for i in 1:N
-    @info "step $i"
-    x = body
-  end
-  x
-end
-
-
-
-=#
+end # module
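To make the module's remaining surface concrete, here is a rough usage sketch of `FluxState`, based on the docstring and the `adjust!`/`_initialise!` code in this commit. It assumes (as the warning message in explicit_train.jl below suggests) that this branch's `Flux.Adam()` returns a `FluxState` wrapping the matching Optimisers.jl rule; shapes and numbers are illustrative, not part of the commit.

using Flux

opt = Flux.Adam(0.01)      # a FluxState holding the rule, with state = missing until first use
model = Dense(3 => 2)
data = [(rand(Float32, 3, 8), rand(Float32, 2, 8))]

Flux.train!((m, x, y) -> Flux.mse(m(x), y), model, data, opt)  # _initialise! fills opt.state via Optimisers.setup

opt.state                  # now an Optimisers.jl state tree matching `model`
Flux.adjust!(opt, 1e-3)    # change the learning rate via the adjust! method shown (truncated) above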

src/train/explicit_train.jl

Lines changed: 15 additions & 13 deletions
@@ -52,26 +52,28 @@ function train!(loss::Function, model, data, opt::FluxState)
   _initialise!(opt, model)
   losses = Float32[]
   s = opt.state
-  s isa IdDict && error("can't mix explicit & implicit!")
+  s isa IdDict && error("""Can't mix explicit & implicit modes!
+    Once `FluxState` is initialised by `train!` in one mode, it cannot be used in the other.""")
   for d in data
-    l, (g, _...) = Zygote.withgradient(loss, model, train_ok(d)...)
+    l, (g, _...) = explicit_withgradient(loss, model, data_splat(d)...)
     s, model = Optimisers.update!(s, model, g)
     push!(losses, l)
    opt.state = s
   end
-  return losses
+  return losses # Not entirely sure returning losses is a good idea. Flux 0.13 returns `nothing`.
 end
 
-train_ok(x::T) where T = error("""train! expects every d in data be a Tuple or a NamedTuple, got $T
-  To allow this type, define `Flux.Optimise.train_ok(x::$T) = (x,)`""")
-train_ok(x::Tuple) = x
-train_ok(x::NamedTuple) = x
+data_splat(x::T) where T = error("""train! expects every d in data be a Tuple or a NamedTuple, got $T
+  To allow this type, define `Flux.Train.data_splat(x::$T) = (x,)`""")
+data_splat(x::Tuple) = x
+data_splat(x::NamedTuple) = x
 
 function _initialise!(opt::FluxState, model)
   if opt.state isa Missing
     opt.state = Optimisers.setup(opt.rule, model)
     fmap(model, exclude = Optimisers.isnumeric) do x
-      Optimisers.maywrite(x) || error("model must be fully mutable for train! to work, got $(typeof(x))")
+      Optimisers.maywrite(x) || error("""model must be fully mutable for train! to work, got x::$(typeof(x))
+        If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""")
     end
   end
   opt
@@ -107,12 +109,12 @@ function train!(loss::Function, model, opt::FluxState)
   l
 end
 
+# This method lets you use Optimisers.Descent() instead of Flux.Descent(), when there is no state
 function train!(loss::Function, model, data, opt::Optimisers.AbstractRule)
   _initialise!(opt, model)
-  # fmap(opt.state) do x
-  #   x isa Union{Number, AbstractArray{<:Number}} && @warn "optimiser state will be lost!"
-  #   x
-  # end # won't work as you need to look inside Leaf for non-nothings.
-  @warn "optimiser state will be lost!"
+  fmap(opt.state, exclude = x -> x isa Optimsers.Leaf) do leaf
+    leaf.state isa Nothing || @warn "Optimiser state will be lost! Please wrap optimisation rule in `FluxState`, e.g. by using `Flux.Adam()`" leaf
+    leaf
+  end
   train!(loss, model, data, FluxState(opt))
 end
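As a usage sketch of the explicit mode defined above (adapted from the commented-out example this commit deletes from Train.jl; data and sizes are illustrative), the model itself is passed to `train!`, and each datum is splatted after the model into the loss:

using Flux

data = [(rand(Float32, 3, 2), rand(Float32, 1, 2)) for _ in 1:50]  # every d must be a Tuple/NamedTuple, see data_splat
model = Chain(Dense(3 => 7, relu), Dense(7 => 1))
opt = Flux.Adam()                                                  # a FluxState on this branch

losses = Flux.train!(model, data, opt) do m, x, y
  Flux.mse(m(x), y)        # the loss receives the model first, then the splatted datum
end
losses                     # Vector{Float32}, one entry per datum (return value flagged as tentative above)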

src/train/implicit_train.jl

Lines changed: 8 additions & 2 deletions
@@ -29,7 +29,7 @@ function train!(loss::Function, pars::Params, data, opt::FluxState)
   losses = Float32[]
   for d in data
     l, grads = Zygote.withgradient(() -> loss(batchmemaybe(d)...), pars)
-    update!(opt, pars, grads)
+    _update!(opt, pars, grads)
     push!(losses, l)
   end
   return losses
@@ -49,7 +49,7 @@ function train!(loss::Function, pars::Params, opt::FluxState)
     Explicit parameters are now preferred, see `train!(loss, model, data, opt)`""", :train!, force=true)
   _initialise!(opt, pars)
   l, grads = Zygote.withgradient(() -> loss(), pars)
-  update!(opt, pars, grads)
+  _update!(opt, pars, grads)
   return l
 end
 
@@ -68,6 +68,12 @@ Legacy method, mimicking the behaviour of Flux <= 0.13.
 """
 function update!(opt::FluxState, xs::Params, gs)
   Base.depwarn("Flux.update! is a legacy function", :update!)
+  _initialise!(opt, xs)
+  _update!(opt, xs, gs)
+end
+# This _update! exists only so that train! above gives one depwarn, not two!
+# ... and also to call _initialise!
+function _update!(opt::FluxState, xs::Params, gs)
   for x in xs
     isnothing(gs[x]) && continue
     update!(opt, x, gs[x])
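For comparison, a sketch of the implicit style these methods keep alive (essentially the Flux 0.13 example removed from Train.jl in this commit): the loss closes over the model, gradients are taken with respect to `Flux.params`, and the deprecated `update!` above now calls `_initialise!` before delegating to `_update!`:

using Flux

model = Chain(Dense(3 => 7, relu), Dense(7 => 1))
data = [(rand(Float32, 3, 2), rand(Float32, 1, 2)) for _ in 1:50]
opt = Flux.Adam()
pars = Flux.params(model)

Flux.train!(pars, data, opt) do x, y
  Flux.mse(model(x), y)    # closure over model; no model argument in this mode
end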

test/layers/conv.jl

Lines changed: 4 additions & 4 deletions
@@ -55,13 +55,13 @@ end
   bias = Conv((2, 2), 1=>3, bias = false);
   ip = zeros(Float32, 28,28,1,1)
   op = zeros(Float32, 27,27,3,1) .+ 2.f0
-  opt = Descent()
+  opt = Flux.Descent()
 
   for _ = 1:10^3
     gs = gradient(Flux.params(bias)) do
       Flux.Losses.mse(bias(ip), op)
     end
-    Flux.Optimise.update!(opt, params(bias), gs)
+    Flux.Optimise.update!(opt, Flux.params(bias), gs)
   end
 
   @test Flux.Losses.mse(bias(ip), op) ≈ 4.f0
@@ -168,7 +168,7 @@ end
 
   x = zeros(Float32, 5, 5, 2, 4)
   m = ConvTranspose((3,3), 2=>3)
-  @test gradient(()->sum(m(x)), params(m)) isa Flux.Zygote.Grads
+  @test gradient(()->sum(m(x)), Flux.params(m)) isa Flux.Zygote.Grads
 
   # test ConvTranspose supports groups argument
   x = randn(Float32, 10, 10, 2, 3)
@@ -178,7 +178,7 @@ end
   m2 = ConvTranspose((3,3), 2=>4, groups=2, pad=SamePad())
   @test size(m2.weight) == (3,3,2,2)
   @test size(m1(x)) == size(m2(x))
-  @test gradient(()->sum(m2(x)), params(m2)) isa Flux.Zygote.Grads
+  @test gradient(()->sum(m2(x)), Flux.params(m2)) isa Flux.Zygote.Grads
 
   x = randn(Float32, 10, 2,1)
   m = ConvTranspose((3,), 2=>4, pad=SamePad(), groups=2)
