Commit 056ae81

make it stricter, to avoid batchmaybe weirdness
1 parent bbc0f85 commit 056ae81

2 files changed: +21 -18 lines changed

src/train.jl

Lines changed: 20 additions & 17 deletions
@@ -63,17 +63,17 @@ according to a particular optimisation rule `opt`.
 * Instead of `loss` being a function which typically accepts two arguments
   (the input `x` and expected output `y` from each element of `data`)
   now it should typically accept three, the first of which is the `model` itself.
-* `data` should iterate tuples or NamedTuples
-* `opt` should be the result of [`Flux.setup`](@ref).
+* `data` must iterate tuples. Each `d in data` is used as `loss(model, d...)`.
+* `opt` should be the result of [`Flux.setup`](@ref), it will warn you if not.
 * Callback functions are not supported.

 For example, with these definitions...
 ```
-data = [(x1, y1), (x2, y2), (x3, y3)];  # each element must be a tuple (or NamedTuple)
+data = [(x1, y1), (x2, y2), (x3, y3)];  # each element must be a tuple

-loss3(m, x, y) = norm(m(x) .- y)  # the model is the first argument
+loss3(m, x, y) = norm(m(x) .- y)  # the model is the first argument

-opt = Flux.setup(Adam(), model)  # explicit setup of optimiser momenta
+opt = Flux.setup(Adam(), model)  # explicit setup of optimiser momenta
 ```
 ...calling `train!(loss3, model, data, opt)` runs a loop much like this:
 ```
@@ -82,19 +82,28 @@ for d in data
     Optimisers.update!(opt, model, ∂L∂m)
 end
 ```
-Stops with a `DomainError` if the loss is infinite or `NaN` at any point.
+You can also write this loop yourself, if you need more flexibility.
+Besides the loop, `train!` will:

-Returns a vector containing the value of the loss function at each datapoint.
+* Stop with a `DomainError` if the loss is infinite or `NaN` at any point.

-The built-in loss functions accept 3 arguments, allowing for instance `train!(Flux.Losses.mse, model, data, opt)`.
+* Return a vector containing the value of the loss function at each datapoint.

-Callback functions are not supported. But see 3-argument `train!(loss, model, opt)` for an
-easy way to construct more complicated training loops.
+* Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl).
+
+Note that the built-in loss functions accept 3 arguments, allowing for instance
+`train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above.
+
+Note that callback functions are not supported. But arbitrary code can be inserted into the loop.
 """
 function train!(loss, model, data, opt)
+  Base.issingletontype(typeof(loss)) || error("""train! with explicit parameter expects a pure loss function.
+    It must not close over the model, like loss(x,y) = mse(model(x), y). """)
   losses = Float32[]
   @withprogress for (i,d) in enumerate(data)
-    l, (g, _...) = explicit_withgradient(loss, model, data_splat(d)...)
+    d isa Tuple || error("""train! expects as data an iterator producing tuples, but got $(typeof(d)).
+      Pass it `((d,) for d in data)`, or use `gradient` and `update!` for more control.""")
+    l, (g, _...) = explicit_withgradient(loss, model, d...)
     isfinite(l) || throw(DomainError("loss function returned $l, stopping training"))
     opt, model = Optimisers.update!(opt, model, g)
     push!(losses, l)
@@ -103,12 +112,6 @@ function train!(loss, model, data, opt)
   return losses  # Not entirely sure returning losses is a good idea
 end

-data_splat(x::T) where T = error("""train! expects every d in data be a Tuple or a NamedTuple, got $T
-  To allow this type, define `Flux.Train.data_splat(x::$T) = (x,)`""")
-data_splat(x::Tuple) = x
-data_splat(x::NamedTuple) = x
-data_splat(x::AbstractArray{<:Number}) = (x,)
-
 """
     train!(loss, model, opt)
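
A minimal usage sketch of `train!` as defined in this diff, not part of the commit: the model, the data shapes, and the loss names `loss3`/`loss1` are invented here purely to illustrate the stricter contract (a pure three-argument loss, tuple-only `data`, optimiser state from `Flux.setup`).

```julia
using Flux

model = Dense(2 => 1)                                                # tiny illustrative model
data  = [(rand(Float32, 2, 8), rand(Float32, 1, 8)) for _ in 1:3]    # every element is a Tuple

loss3(m, x, y) = Flux.Losses.mse(m(x), y)    # pure loss: the model is an argument, not captured

opt = Flux.setup(Adam(), model)              # explicit optimiser state, as required above
losses = Flux.train!(loss3, model, data, opt)    # this version returns the loss at each datapoint

# An iterator of bare arrays would hit the new `d isa Tuple` check; the error
# message suggests wrapping each element in a 1-tuple instead:
xs = [rand(Float32, 2, 8) for _ in 1:3]
loss1(m, x) = sum(abs2, m(x))
Flux.train!(loss1, model, ((x,) for x in xs), opt)
```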

test/train.jl

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ using Random
 end

 @testset "Explicit Flux.train! features" begin
-  # Test that splat accepts NamedTuple
+  # Test errors from wrong kind of iterator
   # Test NaN / Inf early stop
   # Test that loss is returned
 end
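
The placeholder comments in this testset suggest checks along the following lines. This is only a sketch against the stricter `train!` from this commit, not code from the repository; `sqloss`, `nanloss` and the tiny model are invented for illustration.

```julia
using Flux, Test

m   = Dense(1 => 1)
opt = Flux.setup(Descent(0.1), m)
sqloss(m, x, y) = sum(abs2, m(x) .- y)

# wrong kind of iterator: elements are bare vectors, not tuples, so train! should error
@test_throws Exception Flux.train!(sqloss, m, [rand(Float32, 1)], opt)

# a NaN loss should stop training with a DomainError
nanloss(m, x, y) = NaN32
@test_throws DomainError Flux.train!(nanloss, m, [(Float32[1], Float32[2])], opt)

# the loss at each datapoint should be returned
ls = Flux.train!(sqloss, m, [(Float32[1], Float32[2])], opt)
@test ls isa Vector{Float32} && length(ls) == 1
```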

0 commit comments
