@@ -63,17 +63,17 @@ according to a particular optimisation rule `opt`.
* Instead of `loss` being a function which typically accepts two arguments
(the input `x` and expected output `y` from each element of `data`)
now it should typically accept three, the first of which is the `model` itself.
- * `data` should iterate tuples or NamedTuples
- * `opt` should be the result of [`Flux.setup`](@ref).
+ * `data` must iterate tuples. Each `d in data` is used as `loss(model, d...)`.
+ * `opt` should be the result of [`Flux.setup`](@ref); it will warn you if not.
* Callback functions are not supported.
For example, with these definitions...
```
- data = [(x1, y1), (x2, y2), (x3, y3)]; # each element must be a tuple (or NamedTuple)
+ data = [(x1, y1), (x2, y2), (x3, y3)]; # each element must be a tuple
- loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument
+ loss3(m, x, y) = norm(m(x) .- y) # the model is the first argument
- opt = Flux.setup(Adam(), model) # explicit setup of optimiser momenta
+ opt = Flux.setup(Adam(), model) # explicit setup of optimiser momenta
```
...calling `train!(loss3, model, data, opt)` runs a loop much like this:
```
@@ -82,19 +82,28 @@ for d in data
Optimisers.update!(opt, model, ∂L∂m)
end
```
- Stops with a `DomainError` if the loss is infinite or `NaN` at any point.
+ You can also write this loop yourself, if you need more flexibility.
+ Besides the loop, `train!` will:
- Returns a vector containing the value of the loss function at each datapoint.
+ * Stop with a `DomainError` if the loss is infinite or `NaN` at any point.
- The built-in loss functions accept 3 arguments, allowing for instance `train!(Flux.Losses.mse, model, data, opt)`.
+ * Return a vector containing the value of the loss function at each datapoint.
- Callback functions are not supported. But see 3-argument `train!(loss, model, opt)` for an
- easy way to construct more complicated training loops.
+ * Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl).
+
+ Note that the built-in loss functions accept 3 arguments, allowing for instance
+ `train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above.
+
+ Note that callback functions are not supported. But arbitrary code can be inserted into the loop.
"""
function train!(loss, model, data, opt)
+ Base.issingletontype(typeof(loss)) || error("""train! with explicit parameter expects a pure loss function.
+ It must not close over the model, like loss(x,y) = mse(model(x), y).""")
losses = Float32[]
@withprogress for (i,d) in enumerate(data)
- l, (g, _...) = explicit_withgradient(loss, model, data_splat(d)...)
+ d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
+ Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
+ l, (g, _...) = explicit_withgradient(loss, model, d...)
isfinite(l) || throw(DomainError("loss function returned $l, stopping training"))
opt, model = Optimisers.update!(opt, model, g)
push!(losses, l)
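
A minimal usage sketch of the four-argument method documented above may help; the model, data, and loss here are illustrative placeholders, and it assumes `Flux.setup`, `Adam`, and `train!` are available as in this diff:

```
using Flux

# Illustrative toy model and data; any Flux model and any iterator of tuples will do.
model = Dense(2 => 1)
data  = [(rand(Float32, 2), rand(Float32, 1)) for _ in 1:3]   # each element is a tuple (x, y)

loss3(m, x, y) = Flux.Losses.mse(m(x), y)   # the model is the first argument
opt = Flux.setup(Adam(), model)             # explicit optimiser state

losses = Flux.train!(loss3, model, data, opt)   # per the docstring above, one loss per datapoint
```

Per the docstring above, this call also shows a progress bar and stops with a `DomainError` if the loss becomes infinite or `NaN`.
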
@@ -103,12 +112,6 @@ function train!(loss, model, data, opt)
return losses # Not entirely sure returning losses is a good idea
end
- data_splat(x::T) where T = error("""train! expects every d in data be a Tuple or a NamedTuple, got $T
- To allow this type, define `Flux.Train.data_splat(x::$T) = (x,)`""")
- data_splat(x::Tuple) = x
- data_splat(x::NamedTuple) = x
- data_splat(x::AbstractArray{<:Number}) = (x,)
-
"""
train!(loss, model, opt)
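
As the four-argument docstring above notes, the loop can also be written by hand when more flexibility is needed, with arbitrary code taking the place of callbacks. Below is a sketch of that, assuming the public `withgradient` (re-exported from Zygote) and `Optimisers.update!` in place of the internal `explicit_withgradient`; the loss and logging are illustrative:

```
using Flux
import Optimisers

# Hypothetical hand-rolled replacement for train!, following the same steps as the loop above.
function my_train!(model, data, opt)
    for (i, (x, y)) in enumerate(data)
        l, grads = Flux.withgradient(m -> Flux.Losses.mse(m(x), y), model)
        isfinite(l) || error("loss is $l on item $i, stopping")  # same guard as train!
        opt, model = Optimisers.update!(opt, model, grads[1])    # grads[1] is the gradient w.r.t. the model
        @info "step $i" loss=l    # arbitrary code here, where callbacks used to go
    end
    return model
end
```

Here `opt` is again the result of `Flux.setup`, as in the usage sketch earlier.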