
Commit 95b0bb2

Merge pull request #1785 from FluxML/logankilpatrick-patch-6
Update train.jl to add a more detailed `train!` docstring
2 parents 1a01df2 + ac20b6c · commit 95b0bb2

File tree

1 file changed: +12 −7 lines changed


src/optimise/train.jl

Lines changed: 12 additions & 7 deletions
```diff
@@ -82,16 +82,21 @@ batchmemaybe(x::Tuple) = x
 
 """
     train!(loss, params, data, opt; cb)
-
+
+`train!` uses a `loss` function and training `data` to improve the
+[Model parameters](@ref) (`params`) based on a pluggable [Optimisers](@ref) (`opt`).
+
 For each datapoint `d` in `data`, compute the gradient of `loss` with
 respect to `params` through backpropagation and call the optimizer `opt`.
-
 If `d` is a tuple of arguments to `loss` call `loss(d...)`, else call `loss(d)`.
-
-A callback is given with the keyword argument `cb`. For example, this will print
-"training" every 10 seconds (using [`Flux.throttle`](@ref)):
-train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
-
+
+To pass trainable parameters, call [`Flux.params`](@ref) with your model or just the
+layers you want to train, like `train!(loss, params(model), ...)` or `train!(loss, params(model[1:end-2]), ...)` respectively.
+
+[Callbacks](@ref) are given with the keyword argument `cb`. For example, this will print "training"
+every 10 seconds (using [`Flux.throttle`](@ref)):
+`train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))`
+
 The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
 
 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
```
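For readers skimming the new docstring, here is a minimal, self-contained sketch of the call shape it documents, assuming the implicit-parameter training API Flux shipped at the time (`Flux.params`, `Flux.Optimise.train!`, `Descent`); the model, loss function, and toy data below are illustrative stand-ins, not part of this commit:

```julia
using Flux
using Flux: params, throttle
using Flux.Optimise: train!, Descent

# A toy model; `loss` closes over it, so train! only needs (x, y) pairs.
model = Dense(2, 1)
loss(x, y) = Flux.Losses.mse(model(x), y)

# `data` is any iterable of datapoints; each tuple `d` is splatted as `loss(d...)`.
data = [(rand(Float32, 2), rand(Float32, 1)) for _ in 1:100]

opt = Descent(0.1)  # a pluggable optimiser

# As in the docstring's example: print "training" at most once every 10 seconds.
evalcb = throttle(() -> println("training"), 10)

# For a `Chain`, `params(model[1:end-2])` would restrict training to those layers.
train!(loss, params(model), data, opt, cb = evalcb)
```

Calling `Flux.stop()` from inside `evalcb` would interrupt the loop early, and both `opt` and `cb` accept arrays when several optimisers or callbacks are needed, per the docstring's closing lines.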
