@@ -47,9 +47,6 @@ function setup(rule::Optimisers.AbstractRule, model)
   state
 end
 
-# opt = Flux.setup(Adam(), model); train!(model, opt) do m ...
-setup(model, rule::Optimisers.AbstractRule) = setup(rule, model)
-
 """
     train!(loss, model, data, opt)
 
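For reference, a minimal usage sketch (not part of this diff; the model is a hypothetical stand-in) of the call order that remains once the argument-swapped `setup(model, rule)` method above is removed: `Flux.setup` takes the rule first, then the model.

```julia
using Flux  # assumes the explicit-setup API (Flux ≥ 0.13.9)

model = Dense(2 => 1)            # hypothetical model for illustration
opt = Flux.setup(Adam(), model)  # rule first, then model; the swapped order is gone
```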
@@ -112,56 +109,10 @@ function train!(loss, model, data, opt)
   return losses  # Not entirely sure returning losses is a good idea
 end
 
-"""
-    train!(loss, model, opt)
-
-Uses a `loss` function to improve the `model`'s parameters.
-
-While the 4-argument method of `train!` iterates over a dataset,
-this 3-argument method is for a single datapoint, and calls `gradient` just once.
-It expects a function `loss` which takes just one argument, the model.
-For example:
-```
-opt = Flux.setup(Adam(), model)  # explicit setup
-train!(model, opt) do m          # the model is passed to the function as `m`
-  Flux.crossentropy(m(x1), y1)   # but the data point `(x1, y1)` is closed over.
-end
-```
-This calls `Zygote.withgradient(m -> Flux.crossentropy(m(x1), y1), model)`.
-(The `do` block is another syntax for this anonymous function.)
-Then it updates the parameters contained within `model` according to `opt`.
-Finally it returns the value of the loss function.
-
-To iterate over a dataset, writing a loop allows more control than
-calling 4-argument `train!`. For example, this adds printing and an early stop:
-```
-data = Flux.DataLoader((Xtrain, Ytrain), batchsize=32)
-opt = Flux.setup(Adam(), model)
-for (i, d) in enumerate(data)
-  x, y = d
-  ell = Flux.train!(m -> Flux.crossentropy(m(x), y), model, opt)
-  i%10==0 && println("on step \$i, the loss was \$ell")  # prints every 10th step
-  ell<0.1 && break                                       # stops training
-end
-```
-
-!!! note
-    This method has no implicit `Params` analog in Flux ≤ 0.13.
-"""
-function train!(loss, model, opt)
-  l, (g, _...) = explicit_withgradient(loss, model)
-  isfinite(l) || return l
-  _, model = Optimisers.update!(opt, model, g)
-  return l
-end
-
-# These methods let you use Optimisers.Descent() without setup, when there is no state
+# This method lets you use Optimisers.Descent() without setup, when there is no state
 function train!(loss, model, data, rule::Optimisers.AbstractRule)
   train!(loss, model, data, _rule_to_state(model, rule))
 end
-function train!(loss, model, rule::Optimisers.AbstractRule)
-  train!(loss, model, _rule_to_state(model, rule))
-end
 
 function _rule_to_state(model, rule::Optimisers.AbstractRule)
   state = setup(rule, model)
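For reference, a minimal sketch (not part of this diff) of the explicit training loop that the removed docstring recommends in place of the deleted 3-argument `train!`. The names `model`, `Xtrain`, and `Ytrain` are assumed to exist; the structure follows the deleted method (one `withgradient` call, a finiteness check, then an update).

```julia
using Flux, Optimisers, Zygote  # assumes `model`, `Xtrain`, `Ytrain` are defined

data = Flux.DataLoader((Xtrain, Ytrain), batchsize=32)
opt = Flux.setup(Adam(), model)
for (i, (x, y)) in enumerate(data)
    # one gradient call per batch, as the deleted method did for a single datapoint
    ell, (g, _...) = Zygote.withgradient(m -> Flux.crossentropy(m(x), y), model)
    isfinite(ell) || continue    # skip the update on a non-finite loss
    Flux.update!(opt, model, g)  # in-place update of optimiser state and parameters
    i % 10 == 0 && println("on step $i, the loss was $ell")
    ell < 0.1 && break           # early stop
end
```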