using ProgressLogging: @progress, @withprogress, @logprogress
import Zygote: Params, gradient

- """
-     update!(x, x̄)
-
- Update the array `x` according to `x .-= x̄`.
- """
- function update!(x::AbstractArray, x̄)
-   x .-= x̄
- end

"""
    update!(opt, p, g)
@@ -20,13 +12,13 @@ according to optimizer `opt` and the gradients `gs` (the gradient `g`).
As a result, the parameters are mutated and the optimizer's internal state may change.
The gradient could be mutated as well.
"""
- function update!(opt, x, x̄)
+ function update!(opt::AbstractOptimiser, x, x̄)
  x̄r = ArrayInterface.restructure(x, x̄)  # address some cases where Zygote's
                                          # output are not mutable, see #1510
  x .-= apply!(opt, x, x̄r)
end

- function update!(opt, xs::Params, gs)
+ function update!(opt::AbstractOptimiser, xs::Params, gs)
  for x in xs
    isnothing(gs[x]) && continue
    update!(opt, x, gs[x])
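
A rough sketch of how the implicit `update!` path above would be exercised, assuming the usual Flux/Zygote API (`Descent`, `Flux.params`, `gradient`); the model, batch, and loss here are hypothetical and only for illustration:
```
using Flux
using Flux.Optimise: Descent, update!

m = Dense(2, 1)                            # hypothetical model
ps = Flux.params(m)                        # implicit parameter collection
x = rand(Float32, 2, 8)                    # hypothetical input batch
gs = gradient(() -> sum(abs2, m(x)), ps)   # implicit-style gradients over `ps`
opt = Descent(0.1)
update!(opt, ps, gs)                       # mutates each parameter array in place
```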
@@ -81,28 +73,44 @@ batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x

"""
-     train!(loss, params, data, opt; cb)
-
- `train!` uses a `loss` function and training `data` to improve the
- [Model parameters](@ref) (`params`) based on a pluggable [Optimisers](@ref) (`opt`).
+     train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])

- For each datapoint `d` in `data`, compute the gradient of `loss` with
- respect to `params` through backpropagation and call the optimizer `opt`.
- If `d` is a tuple of arguments to `loss` call `loss(d...)`, else call `loss(d)`.
-
- To pass trainable parameters, call [`Flux.params`](@ref) with your model or just the
- layers you want to train, like `train!(loss, params(model), ...)` or `train!(loss, params(model[1:end-2), ...)` respectively.
+ Uses a `loss` function and training `data` to improve the
+ model's parameters according to a particular optimisation rule `opt`.

- [Callbacks](@ref) are given with the keyword argument `cb`. For example, this will print "training"
- every 10 seconds (using [`Flux.throttle`](@ref)):
- `train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))`
-
+ For each `d in data`, first the gradient of the `loss` is computed like this:
+ ```
+ gradient(() -> loss(d...), pars)   # if d isa Tuple
+ gradient(() -> loss(d), pars)      # otherwise
+ ```
+ Here `pars` is produced by calling [`Flux.params`](@ref) on your model.
+ (Or just on the layers you want to train, like `train!(loss, params(model[1:end-2]), data, opt)`.)
+ This is the "implicit" style of parameter handling.
+
+ Then, this gradient is used by optimizer `opt` to update the parameters:
+ ```
+ update!(opt, pars, grads)
+ ```
+ The optimiser should be from the [Flux.Optimise](@ref) module.
+ Different optimisers can be combined using [Flux.Optimise.Optimiser](@ref).
+
+ This training loop iterates through `data` once.
+ You can use [`@epochs`](@ref) to do this several times, or
+ use for instance `repeat` to make a longer `data` iterator.
+
+ ## Callbacks
+
+ [Callbacks](@ref) are given with the keyword argument `cb`.
+ For example, this will print "training" every 10 seconds (using [`Flux.throttle`](@ref)):
+ ```
+ train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
+ ```
+
The callback can call [`Flux.stop`](@ref) to interrupt the training loop.

- Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
+ Multiple callbacks can be passed to `cb` as an array.
"""
- function train!(loss, ps, data, opt; cb = () -> ())
-   ps = Params(ps)
+ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
  cb = runall(cb)
  n = (Base.IteratorSize(typeof(data)) == Base.HasLength()) ? length(data) : 0
  @withprogress for (i, d) in enumerate(data)
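
A minimal sketch of calling this implicit-style `train!`, assuming the standard Flux API named in the docstring (`Dense`, `Flux.Losses.mse`, `Flux.params`, `Descent`, `Flux.throttle`); the model, data, and hyperparameters are invented for illustration:
```
using Flux
using Flux.Optimise: Descent, train!

m = Dense(10, 1)                                    # hypothetical model
loss(x, y) = Flux.Losses.mse(m(x), y)               # loss on one (x, y) batch
data = [(rand(Float32, 10, 16), rand(Float32, 1, 16)) for _ in 1:100]   # fake batches
opt = Descent(0.01)

# Each `d in data` is a Tuple, so `loss(d...)` is called internally.
train!(loss, Flux.params(m), data, opt,
       cb = Flux.throttle(() -> println("training"), 10))
```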