
Commit 4e937df

remove 3-argument train!, since it requires an impure loss function, and you can really just use update! instead.

1 parent fa022b3 · commit 4e937df

File tree

3 files changed (+1 −74 lines):
  src/deprecations.jl
  src/train.jl
  test/train.jl
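
As the commit message says, the removed 3-argument `train!(loss, model, opt)` only ever took one gradient step, so `Optimisers.update!` covers the same ground. Below is a minimal sketch of that replacement; the `Zygote.withgradient` and `Optimisers.update!` calls mirror the body of the deleted method in src/train.jl further down, while the toy `model`, `x1`, `y1` here are hypothetical stand-ins added only so the sketch runs on its own:

    using Flux
    import Zygote, Optimisers

    # Hypothetical toy setup, just so the sketch is self-contained:
    model = Dense(10 => 2)
    x1, y1 = rand(Float32, 10), rand(Float32, 2)

    # Explicit optimiser state, as before:
    opt_state = Flux.setup(Adam(), model)

    # One gradient step, written with update! instead of the removed train!(loss, model, opt):
    l, grads = Zygote.withgradient(m -> Flux.Losses.mse(m(x1), y1), model)
    opt_state, model = Optimisers.update!(opt_state, model, grads[1])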

src/deprecations.jl

Lines changed: 0 additions & 8 deletions

@@ -89,22 +89,14 @@ Base.@deprecate_binding ADADelta AdaDelta
 # Valid methods in Train, new explict style, are:
   train!(loss, model, data, opt)
   train!(loss, model, data, opt::Optimisers.AbstractRule)
-# ... and 3-arg:
-  train!(loss, model, opt)
-  train!(loss, model, opt::Optimisers.AbstractRule)
 # Provide friendly errors for what happens if you mix these up:
 =#
 import .Optimise: train!

 train!(loss, ps::Params, data, opt) = error("can't mix implict Params with explict state")
-train!(loss, ps::Params, opt) = error("can't mix implict Params with explict state")

 train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule")
-train!(loss, ps::Params, opt::Optimisers.AbstractRule) = error("can't mix implict Params with explict rule")

 train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))
-train!(loss, model, opt::Optimise.AbstractOptimiser) = train!(loss, model, _old_to_new(opt))
-
-train!(loss, ps::Params, opt::Optimise.AbstractOptimiser; cb=0) = error("3-arg train does not exist for implicit mode")

 # train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
 #     """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.

src/train.jl

Lines changed: 1 addition & 50 deletions

@@ -47,9 +47,6 @@ function setup(rule::Optimisers.AbstractRule, model)
   state
 end

-# opt = Flux.setup(Adam(), model); train!(model, opt) do m ...
-setup(model, rule::Optimisers.AbstractRule) = setup(rule, model)
-
 """
     train!(loss, model, data, opt)

@@ -112,56 +109,10 @@ function train!(loss, model, data, opt)
   return losses  # Not entirely sure returning losses is a good idea
 end

-"""
-    train!(loss, model, opt)
-
-Uses a `loss` function improve the `model`'s parameters.
-
-While the 4-argument method of `train!` iterates over a dataset,
-this 3-argument method is for a single datapoint, and calls `gradient` just once.
-It expects a function `loss` which takes just one argument, the model.
-For example:
-```
-opt = Flux.setup(Adam(), model)  # explicit setup
-train!(model, opt) do m          # the model is passed to the function as `m`
-  Flux.crossentropy(m(x1), y1)   # but the data point `(x1, y1)` is closed over.
-end
-```
-This calls `Zygote.withgradient(m -> Flux.crossentropy(m(x1), y1), model)`.
-(The `do` block is another syntax for this anonymous function.)
-Then it updates the parameters contained within `model` according to `opt`.
-Finally it returns the value of the loss function.
-
-To iterate over a dataset, writing a loop allows more control than
-calling 4-argument `train!`. For example, this adds printing and an early stop:
-```
-data = Flux.DataLoader((Xtrain, Ytrain), batchsize=32)
-opt = Flux.setup(Adam(), model)
-for (i, d) in enumerate(data)
-  x, y = d
-  ell = Flux.train!(m -> Flux.crossentropy(m(x), y), model, opt)
-  i%10==0 && println("on step \$i, the loss was \$ell")  # prints every 10th step
-  ell<0.1 && break                                      # stops training
-end
-```
-
-!!! note
-    This method has no implicit `Params` analog in Flux ≤ 0.13.
-"""
-function train!(loss, model, opt)
-  l, (g, _...) = explicit_withgradient(loss, model)
-  isfinite(l) || return l
-  _, model = Optimisers.update!(opt, model, g)
-  return l
-end
-
-# These methods let you use Optimisers.Descent() without setup, when there is no state
+# This method let you use Optimisers.Descent() without setup, when there is no state
 function train!(loss, model, data, rule::Optimisers.AbstractRule)
   train!(loss, model, data, _rule_to_state(model, rule))
 end
-function train!(loss, model, rule::Optimisers.AbstractRule)
-  train!(loss, model, _rule_to_state(model, rule))
-end

 function _rule_to_state(model, rule::Optimisers.AbstractRule)
   state = setup(rule, model)
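
For comparison, the loop in the removed docstring can be written after this commit with an explicit per-batch gradient and `update!` step. This is a sketch only, assuming `model`, `Xtrain`, and `Ytrain` exist as in that docstring, and that Zygote and Optimisers are loaded alongside Flux:

    using Flux
    import Zygote, Optimisers

    data = Flux.DataLoader((Xtrain, Ytrain), batchsize=32)
    opt_state = Flux.setup(Adam(), model)

    for (i, (x, y)) in enumerate(data)
      ell, grads = Zygote.withgradient(m -> Flux.crossentropy(m(x), y), model)
      isfinite(ell) || continue                        # skip the step on a non-finite loss, as train! did
      Optimisers.update!(opt_state, model, grads[1])   # updates state and model parameters in place
      i % 10 == 0 && println("on step $i, the loss was $ell")  # prints every 10th step
      ell < 0.1 && break                               # stops training early
    end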

test/train.jl

Lines changed: 0 additions & 16 deletions

@@ -22,22 +22,6 @@ using Random
     @test loss(model, rand(10, 10)) < 0.01
   end

-  # Test 3-arg `Flux.train!` method:
-  @testset for rule in [Descent(0.1), Adam(), AdamW()]
-
-    loss(m) = let x = rand(10)
-      Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
-    end
-    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
-    @test loss(model) > 1
-
-    opt = Flux.setup(rule, model)
-    for i in 1:10^5
-      Flux.train!(loss, model, opt)
-    end
-    @test loss(model) < 0.01
-  end
-
   # Test direct use of Optimisers.jl rule, only really OK for `Descent`:
   @testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()]
     loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)

0 commit comments