@@ -82,16 +82,21 @@ batchmemaybe(x::Tuple) = x

"""
    train!(loss, params, data, opt; cb)
-
+
+ `train!` uses a `loss` function and training `data` to improve the
+ [Model parameters](@ref) (`params`) based on a pluggable [Optimisers](@ref) (`opt`).
+
For each datapoint `d` in `data`, compute the gradient of `loss` with
respect to `params` through backpropagation and call the optimizer `opt`.
-
If `d` is a tuple of arguments to `loss` call `loss(d...)`, else call `loss(d)`.
-
- A callback is given with the keyword argument `cb`. For example, this will print
- "training" every 10 seconds (using [`Flux.throttle`](@ref)):
- train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
-
+
+ To pass trainable parameters, call [`Flux.params`](@ref) with your model or just the
+ layers you want to train, like `train!(loss, params(model), ...)` or `train!(loss, params(model[1:end-2]), ...)` respectively.
+
+ [Callbacks](@ref) are given with the keyword argument `cb`. For example, this will print "training"
+ every 10 seconds (using [`Flux.throttle`](@ref)):
+ `train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))`
+
The callback can call [`Flux.stop`](@ref) to interrupt the training loop.

Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
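
For reference, a minimal sketch of the usage the updated docstring describes, assuming the implicit-parameter API this PR targets (`Flux.params` / `Flux.train!` with the `cb` keyword); the model, loss function, and data below are illustrative placeholders, not part of the change:

```julia
using Flux
using Flux: train!, throttle, params

# Hypothetical two-layer model and a simple squared-error loss over it.
model = Chain(Dense(10, 5, relu), Dense(5, 2))
loss(x, y) = sum(abs2, model(x) .- y)

# Each datapoint is a tuple (x, y), so train! calls loss(x, y) via loss(d...).
data = [(rand(Float32, 10), rand(Float32, 2)) for _ in 1:100]
opt = Descent(0.1)

# Train on the model's parameters; print "training" at most every 10 seconds.
train!(loss, params(model), data, opt; cb = throttle(() -> println("training"), 10))
```

Passing a restricted collection such as `params(model[1:end-1])` instead would leave the final layer untouched, which is the partial-training pattern the new docstring text points at.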