
Commit ed78e8a

RFC: Restrict train! to AbstractOptimiser (#1902)
1 parent b6dbefb commit ed78e8a

6 files changed, +49 −35 lines

.github/workflows/Downstream.yml

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,4 @@
 name: Downstream
-
 on:
   push:
     branches: [master]
@@ -10,6 +9,8 @@ jobs:
   test:
     name: ${{ matrix.package.repo }}/${{ matrix.package.group }}
     runs-on: ${{ matrix.os }}
+    env:
+      GROUP: ${{ matrix.package.group }}
     strategy:
       fail-fast: false
       matrix:
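
The new `GROUP` variable only exports an environment value; how it is consumed is up to each downstream package's own test script. A minimal sketch of the kind of test selection it typically drives (the file and group names below are hypothetical, not part of this commit):

```
# Hypothetical downstream test/runtests.jl: pick a test group based on the
# GROUP environment variable exported by the workflow above.
const GROUP = get(ENV, "GROUP", "All")

if GROUP == "All" || GROUP == "Core"
    include("core.jl")      # hypothetical test file
end
if GROUP == "All" || GROUP == "Layers"
    include("layers.jl")    # hypothetical test file
end
```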

src/deprecations.jl

Lines changed: 5 additions & 0 deletions
@@ -34,6 +34,11 @@ struct Zeros
 end
 Zeros(args...) = Zeros() # was used both Dense(10, 2, initb = Zeros) and Dense(rand(2,10), Zeros())
 
+function Optimise.update!(x::AbstractArray, x̄)
+  depwarn("`Flux.Optimise.update!(x, x̄)` was not used internally and has been removed. Please write `x .-= x̄` instead.", :update!)
+  x .-= x̄
+end
+
 # Channel notation: Changed to match Conv, but very softly deprecated!
 # Perhaps change to @deprecate for v0.14, but there is no plan to remove these.
 Dense(in::Integer, out::Integer, σ = identity; kw...) =
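
For code that still calls the removed two-argument method, the deprecation message above already names the replacement. A small sketch of the migration (the array names here are illustrative):

```
using Flux

W  = rand(3, 3)
ΔW = 0.01 .* rand(3, 3)

# Before (now a deprecation warning): Flux.Optimise.update!(W, ΔW)
# After: do the plain in-place update directly.
W .-= ΔW
```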

src/optimise/optimisers.jl

Lines changed: 2 additions & 2 deletions
@@ -521,7 +521,7 @@ opt = AdaBelief()
 opt = AdaBelief(0.001, (0.9, 0.8))
 ```
 """
-mutable struct AdaBelief
+mutable struct AdaBelief <: AbstractOptimiser
   eta::Float64
   beta::Tuple{Float64,Float64}
   epsilon::Float64
@@ -553,7 +553,7 @@ mutable struct Optimiser <: AbstractOptimiser
   os::Vector{Any}
 end
 
-Optimiser(o...) = Optimiser(Any[o...])
+Optimiser(opts::AbstractOptimiser...) = Optimiser(Any[opts...])
 
 @forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex!
 @forward Optimiser.os Base.iterate
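
With the constructor restricted to `AbstractOptimiser`, composing rules still works as before, but passing something that is not an optimiser now fails at construction time rather than later inside `apply!`. A quick sketch (the particular rules chosen here are just for illustration):

```
using Flux.Optimise: Optimiser, WeightDecay, Descent

# Fine: both arguments are AbstractOptimisers.
opt = Optimiser(WeightDecay(1e-4), Descent(0.01))

# Optimiser(0.01, Descent(0.01))   # now a MethodError at construction
```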

src/optimise/train.jl

Lines changed: 35 additions & 27 deletions
@@ -1,14 +1,6 @@
 using ProgressLogging: @progress, @withprogress, @logprogress
 import Zygote: Params, gradient
 
-"""
-    update!(x, x̄)
-
-Update the array `x` according to `x .-= x̄`.
-"""
-function update!(x::AbstractArray, x̄)
-  x .-= x̄
-end
 
 """
     update!(opt, p, g)
@@ -20,13 +12,13 @@ according to optimizer `opt` and the gradients `gs` (the gradient `g`).
 As a result, the parameters are mutated and the optimizer's internal state may change.
 The gradient could be mutated as well.
 """
-function update!(opt, x, x̄)
+function update!(opt::AbstractOptimiser, x, x̄)
   x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's
                                         # output are not mutable, see #1510
   x .-= apply!(opt, x, x̄r)
 end
 
-function update!(opt, xs::Params, gs)
+function update!(opt::AbstractOptimiser, xs::Params, gs)
   for x in xs
     isnothing(gs[x]) && continue
     update!(opt, x, gs[x])
@@ -81,28 +73,44 @@ batchmemaybe(x) = tuple(x)
 batchmemaybe(x::Tuple) = x
 
 """
-    train!(loss, params, data, opt; cb)
-
-`train!` uses a `loss` function and training `data` to improve the
-[Model parameters](@ref) (`params`) based on a pluggable [Optimisers](@ref) (`opt`).
+    train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
 
-For each datapoint `d` in `data`, compute the gradient of `loss` with
-respect to `params` through backpropagation and call the optimizer `opt`.
-If `d` is a tuple of arguments to `loss` call `loss(d...)`, else call `loss(d)`.
-
-To pass trainable parameters, call [`Flux.params`](@ref) with your model or just the
-layers you want to train, like `train!(loss, params(model), ...)` or `train!(loss, params(model[1:end-2]), ...)` respectively.
+Uses a `loss` function and training `data` to improve the
+model's parameters according to a particular optimisation rule `opt`.
 
-[Callbacks](@ref) are given with the keyword argument `cb`. For example, this will print "training"
-every 10 seconds (using [`Flux.throttle`](@ref)):
-`train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))`
-
+For each `d in data`, first the gradient of the `loss` is computed like this:
+```
+    gradient(() -> loss(d...), pars)  # if d isa Tuple
+    gradient(() -> loss(d), pars)     # otherwise
+```
+Here `pars` is produced by calling [`Flux.params`](@ref) on your model.
+(Or just on the layers you want to train, like `train!(loss, params(model[1:end-2]), data, opt)`.)
+This is the "implicit" style of parameter handling.
+
+Then, this gradient is used by optimizer `opt` to update the parameters:
+```
+    update!(opt, pars, grads)
+```
+The optimiser should be from the [Flux.Optimise](@ref) module.
+Different optimisers can be combined using [Flux.Optimise.Optimiser](@ref).
+
+This training loop iterates through `data` once.
+You can use [`@epochs`](@ref) to do this several times, or
+use for instance `Iterators.repeat` to make a longer `data` iterator.
+
+## Callbacks
+
+[Callbacks](@ref) are given with the keyword argument `cb`.
+For example, this will print "training" every 10 seconds (using [`Flux.throttle`](@ref)):
+```
+train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
+```
+
 The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
 
-Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
+Multiple callbacks can be passed to `cb` as an array.
 """
-function train!(loss, ps, data, opt; cb = () -> ())
-  ps = Params(ps)
+function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
   cb = runall(cb)
   n = (Base.IteratorSize(typeof(data)) == Base.HasLength()) ? length(data) : 0
   @withprogress for (i, d) in enumerate(data)
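
Under the tightened signature, `train!` requires an explicit `Params` collection and a Flux optimiser. A minimal end-to-end sketch of the implicit-style loop the docstring describes (the toy model and data below are made up for illustration):

```
using Flux
using Flux.Optimise: Descent

model = Dense(4, 1)                                   # toy model
data  = [(rand(Float32, 4, 8), rand(Float32, 1, 8)) for _ in 1:10]

loss(x, y) = Flux.Losses.mse(model(x), y)

ps  = Flux.params(model)   # a Params object, as the new method signature requires
opt = Descent(0.1)         # any AbstractOptimiser from Flux.Optimise

Flux.train!(loss, ps, data, opt;
            cb = Flux.throttle(() -> println("training..."), 10))
```

Passing a plain vector of arrays is no longer accepted; wrap it as `Params([θ])`, as the test changes below do.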

test/data.jl

Lines changed: 2 additions & 2 deletions
@@ -76,7 +76,7 @@ using Random
     X = zeros(2, 10)
     loss(x) = sum((x .- θ).^2)
     d = DataLoader(X)
-    Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
+    Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
     @test norm(θ) < 1e-4
 
     # test interaction with `train!`
@@ -85,7 +85,7 @@ using Random
     Y = fill(2, 10)
     loss(x, y) = sum((y - x'*θ).^2)
     d = DataLoader((X, Y))
-    Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
+    Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
     @test norm(θ .- 1) < 1e-10
 
     # specify the rng

test/optimise.jl

Lines changed: 3 additions & 3 deletions
@@ -50,7 +50,7 @@ end
   l = 1
   Flux.train!(
     () -> (sleep(0.1); Flux.skip(); i+=1),
-    (),
+    Params([]),
    Iterators.repeated((), 10),
     Descent()
   )
@@ -59,7 +59,7 @@ end
 
   Flux.train!(
     () -> (sleep(0.1); i==8 && Flux.skip(); i+=1),
-    (),
+    Params([]),
     Iterators.repeated((), 10),
     Descent()
   )
@@ -68,7 +68,7 @@ end
 
   i = 0
   Flux.train!(() -> (sleep(0.1); i += 1; l),
-    (),
+    Params([]),
     Iterators.repeated((), 100),
     Descent(),
     cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
