Commit 25eea17

Merge branch 'master' into linear-regression
2 parents 6b64b58 + 8d948e8

File tree

10 files changed: +368 -38 lines


NEWS.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -2,6 +2,7 @@
 
 ## v0.13.7
 * Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078)
+* New method of `train!` using Zygote's "explicit" mode. Part of a move away from "implicit" `Params`.
 
 ## v0.13.4
 * Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)
```
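For readers of the changelog, a minimal sketch of the call-style change this entry refers to; the `loss_xy`/`loss_mxy` names follow the commented-out deprecation message in `src/deprecations.jl` below and are illustrative only, not part of the commit:

```julia
# Old, implicit style: the loss closes over the model, and Params are passed explicitly.
#   train!(loss_xy, Flux.params(model), data, Adam())

# New, explicit style: the model is the 2nd argument and the loss receives it first.
#   opt = Flux.setup(Adam(), model)
#   train!(loss_mxy, model, data, opt)
```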

Project.toml

Lines changed: 4 additions & 6 deletions
```diff
@@ -1,10 +1,9 @@
 name = "Flux"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.13.6"
+version = "0.13.8"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@@ -27,16 +26,15 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 Adapt = "3.0"
-ArrayInterface = "3.1, 4, 5, 6"
 CUDA = "3"
 ChainRulesCore = "1.12"
 Functors = "0.3"
-MLUtils = "0.2"
+MLUtils = "0.2, 0.3.1"
 MacroTools = "0.5"
 NNlib = "0.8.9"
 NNlibCUDA = "0.2.4"
-OneHotArrays = "0.1"
-Optimisers = "0.2.1"
+OneHotArrays = "0.1, 0.2"
+Optimisers = "0.2.10"
 ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 SpecialFunctions = "1.8.2, 2.1.2"
```

docs/src/models/overview.md

Lines changed: 16 additions & 26 deletions
````diff
@@ -77,13 +77,15 @@ julia> predict(x_train)
 In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions.
 
 ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
-julia> loss(x, y) = Flux.Losses.mse(predict(x), y);
+julia> using Statistics
 
-julia> loss(x_train, y_train)
+julia> loss(model, x, y) = mean(abs2.(model(x) .- y));
+
+julia> loss(predict, x_train, y_train)
 122.64734f0
 ```
 
-More accurate predictions will yield a lower loss. You can write your own loss functions or rely on those already provided by Flux. This loss function is called [mean squared error](https://www.statisticshowto.com/probability-and-statistics/statistics-definitions/mean-squared-error/). Flux works by iteratively reducing the loss through *training*.
+More accurate predictions will yield a lower loss. You can write your own loss functions or rely on those already provided by Flux. This loss function is called [mean squared error](https://www.statisticshowto.com/probability-and-statistics/statistics-definitions/mean-squared-error/) (and built-in as [`mse`](@ref Flux.Losses.mse)). Flux works by iteratively reducing the loss through *training*.
 
 ## 3. Improve the Prediction
 
@@ -112,40 +114,28 @@ julia> predict.bias
 0.0
 ```
 
-The dimensions of these model parameters depend on the number of inputs and outputs. Since models can have hundreds of inputs and several layers, it helps to have a function to collect the parameters into the data structure Flux expects:
-
-```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
-julia> parameters = Flux.params(predict)
-Params([Float32[0.9066542], Float32[0.0]])
-```
-
-These are the parameters Flux will change, one step at a time, to improve predictions. At each step, the contents of this `Params` object changes too, since it is just a collection of references to the mutable arrays inside the model:
-
-```jldoctest overview
-julia> predict.weight in parameters, predict.bias in parameters
-(true, true)
-```
+The dimensions of these model parameters depend on the number of inputs and outputs.
 
-The first parameter is the weight and the second is the bias. Flux will adjust predictions by iteratively changing these parameters according to the optimizer.
+Flux will adjust predictions by iteratively changing these parameters according to the optimizer.
 
 This optimiser implements the classic gradient descent strategy. Now improve the parameters of the model with a call to [`Flux.train!`](@ref) like this:
 
 ```jldoctest overview
-julia> train!(loss, parameters, data, opt)
+julia> train!(loss, predict, data, opt)
 ```
 
 And check the loss:
 
 ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
-julia> loss(x_train, y_train)
+julia> loss(predict, x_train, y_train)
 116.38745f0
 ```
 
 It went down. Why?
 
 ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
-julia> parameters
-Params([Float32[7.5777884], Float32[1.9466728]])
+julia> predict.weight, predict.bias
+(Float32[7.5777884], Float32[1.9466728])
 ```
 
 The parameters have changed. This single step is the essence of machine learning.
@@ -156,14 +146,14 @@ In the previous section, we made a single call to `train!` which iterates over t
 
 ```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
 julia> for epoch in 1:200
-           train!(loss, parameters, data, opt)
+           train!(loss, predict, data, opt)
        end
 
-julia> loss(x_train, y_train)
+julia> loss(predict, x_train, y_train)
 0.00339581f0
 
-julia> parameters
-Params([Float32[4.0178537], Float32[2.0050256]])
+julia> predict.weight, predict.bias
+(Float32[4.0178537], Float32[2.0050256])
 ```
 
 After 200 training steps, the loss went down, and the parameters are getting close to those in the function the model is built to predict.
@@ -188,7 +178,7 @@ First, we gathered real-world data into the variables `x_train`, `y_train`, `x_t
 
 Then, we built a single input, single output predictive model, `predict = Dense(1 => 1)`. The initial predictions weren't accurate, because we had not trained the model yet.
 
-After building the model, we trained it with `train!(loss, parameters, data, opt)`. The loss function is first, followed by the `parameters` holding the weights and biases of the model, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process.
+After building the model, we trained it with `train!(loss, predict, data, opt)`. The loss function is first, followed by the model itself, the training data, and the `Descent` optimizer provided by Flux. We ran the training step once, and observed that the parameters changed and the loss went down. Then, we ran the `train!` many times to finish the training process.
 
 After we trained the model, we verified it with the test data to verify the results.
 
````
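Pulling the updated guide together, a self-contained sketch of the new explicit-style workflow shown in this file; the training data values are placeholders (the real ones come from earlier sections of the guide not touched by this diff), and it assumes the old-style `Descent()` rule is accepted by `train!` via the translation added in `src/deprecations.jl`:

```julia
using Flux, Statistics

# Placeholder data standing in for the guide's x_train/y_train (assumption).
x_train = Float32.(hcat(0:5...))        # 1×6 matrix of inputs
y_train = 4 .* x_train .+ 2             # targets from a known linear function

predict = Dense(1 => 1)                          # single-input, single-output model
loss(model, x, y) = mean(abs2.(model(x) .- y))   # mean squared error, as in the guide

data = [(x_train, y_train)]
opt = Descent()                                  # old-style rule, translated for train!

train!(loss, predict, data, opt)                 # one step: the loss should drop
for epoch in 1:200
    train!(loss, predict, data, opt)
end

predict.weight, predict.bias                     # should approach (4, 2) for this data
```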

src/Flux.jl

Lines changed: 4 additions & 0 deletions
```diff
@@ -34,6 +34,10 @@ export Descent, Adam, Momentum, Nesterov, RMSProp,
   AdamW, RAdam, AdaBelief, InvDecay, ExpDecay,
   WeightDecay, ClipValue, ClipNorm
 
+include("train.jl")
+using .Train
+# using .Train: setup, @train_autodiff
+
 using CUDA
 const use_cuda = Ref{Union{Nothing,Bool}}(nothing)
 
```

src/deprecations.jl

Lines changed: 102 additions & 0 deletions
```diff
@@ -82,3 +82,105 @@ Base.@deprecate_binding ADAGrad AdaGrad
 Base.@deprecate_binding ADADelta AdaDelta
 
 @deprecate rng_from_array() default_rng_value()
+
+#=
+  # Valid method in Optimise, old implicit style, is:
+  train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
+
+  # Valid methods in Train, new explict style, are:
+  train!(loss, model, data, opt)  # preferred
+  train!(loss, model, data, opt::Optimisers.AbstractRule)  # if you forget setup
+
+  # Provide friendly errors for what happens if you mix these up:
+=#
+import .Optimise: train!
+
+train!(loss, ps::Params, data, opt) = error(
+  """can't mix implict Params with explict state!
+  To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module.
+  But better to use the new explicit style, in which `m` itself is the 2nd argument.
+  """)
+
+train!(loss, ps::Params, data, opt::Optimisers.AbstractRule) = error(
+  """can't mix implict Params with explict rule from Optimisers.jl
+  To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module.
+  But better to use the new explicit style, in which `m` itself is the 2nd argument.
+  """)
+
+train!(loss, model, data, opt::Optimise.AbstractOptimiser) = train!(loss, model, data, _old_to_new(opt))
+
+# Next, to use the new `setup` with the still-exported old-style `Adam` etc:
+import .Train: setup
+setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)
+# ... and allow accidental use of `Optimisers.setup` to do the same:
+Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model)
+
+for T in [:Descent, :Adam, :Momentum, :Nesterov,
+          :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief,
+          # :InvDecay, :ExpDecay,
+         ]
+  @eval function _old_to_new(rule::$T)
+    args = map(f -> getfield(rule, f), fieldnames(Optimisers.$T))
+    Optimisers.$T(args...)
+  end
+end
+_old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
+const OptimiserChain = Optimise.Optimiser  # lets you use new name with implicit params too.
+_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd)  # called gamma now
+_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thesh)  # called omega, and there are more fields
+_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thesh)  # called delta now, and struct name differs
+const ClipGrad = Optimise.ClipValue
+_old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon)  # RMSProp has no field centred
+
+_old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule")
+
+# Since `update!` should be called in a loop, it makes less sense to call `setup` for you if you forgot.
+# But let's make sure that such uses give a helpful error:
+import .Optimise: update!
+
+function update!(opt::Optimise.AbstractOptimiser, model, grad)
+  # This error method requires narrowing the main worker method of Flux.Optimise
+  # to accept only arrays. Remove if this causes problems!
+  # update!(opt::Flux.Optimise.AbstractOptimiser, x::AbstractArray, x̄)
+  error("""Invalid input to `update!`.
+    * For the implicit style, this needs `update(::AbstractOptimiser, ::Params, ::Grads)`
+    * For the explicit style, `update(state, model, grad)` needs `state = Flux.setup(opt, model)`.
+    """)
+end
+
+# An easy error to make is to pass result of explicit gradient(...), not gradient(...)[1]
+# Can't catch every case, but can catch many simple Flux models:
+
+function update!(opt, model::Chain, grads::Tuple)
+  # Zygote will make a NamedTuple{(:layers,)} for the gradient of Chain, Diffractor a Tangent
+  @warn """explicit `update!(opt, model, grad)` wants the gradient for the model alone,
+    not the whole tuple from `gradient(m -> loss(m, x, y), model)`. You probably want `grads[1]`."""
+  update!(opt, model, grads[1])
+end
+
+function update!(opt::Optimise.AbstractOptimiser, model::Chain, grads::Tuple)  # ambiguity
+  update!(opt, model, grads[1])  # calls error case "Invalid input" just above
+end
+
+# One more easy error to catch is using explicit gradient with `params(m)`:
+
+function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple, NamedTuple})
+  error("""can't mix implicit Params with explicit gradients!
+    * For the implicit style, this needs `update(::AbstractOptimiser, ::Params, ::Grads)` with implicit gradient.
+    * For the explicit style, `update(state, model, grad)` needs the model itself, and `state = Flux.setup(opt, model)`.
+    """)
+end
+
+# v0.14 deprecations
+
+# Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc:
+# Base.@deprecate_binding Optimiser OptimiserChain
+# Base.@deprecate_binding ClipValue ClipGrad
+
+# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
+#   """On Flux 0.14, `train!` no longer accepts implicit `Zygote.Params`.
+#   Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
+#   it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
+#   where `loss_mxy` accepts the model as its first argument.
+#   """
+# ))
```
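As an illustration of the compatibility layer above, a hypothetical usage sketch (not part of the commit); the model, loss, and data are invented, and `Flux.setup` is assumed to be reachable via the `using .Train` wiring shown in `src/Flux.jl`:

```julia
using Flux

model = Chain(Dense(2 => 3, relu), Dense(3 => 1))
loss(m, x, y) = Flux.Losses.mse(m(x), y)
x, y = rand(Float32, 2, 16), rand(Float32, 1, 16)

# An old-style rule from Flux.Optimise is translated by `_old_to_new` inside `setup`:
state = Flux.setup(Adam(0.01), model)   # roughly Optimisers.setup(Optimisers.Adam(0.01), model)

# New explicit style: the model itself is the 2nd argument.
Flux.train!(loss, model, [(x, y)], state)

# Mixing styles hits the friendly error methods defined above, e.g.
# Flux.train!(loss, Flux.params(model), [(x, y)], state)   # "can't mix implict Params with explict state!"
```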

src/optimise/Optimise.jl

Lines changed: 0 additions & 1 deletion
```diff
@@ -1,7 +1,6 @@
 module Optimise
 
 using LinearAlgebra
-import ArrayInterface
 
 export train!, update!,
   Descent, Adam, Momentum, Nesterov, RMSProp,
```

src/optimise/train.jl

Lines changed: 15 additions & 4 deletions
````diff
@@ -1,20 +1,27 @@
 using ProgressLogging: @progress, @withprogress, @logprogress
 import Zygote: Params, gradient, withgradient
 
+# Add methods to Optimisers.jl's function, so that there is just one Flux.update!
+# for both explicit and implicit parameters.
+import Optimisers.update!
 
 """
     update!(opt, p, g)
     update!(opt, ps::Params, gs)
 
 Perform an update step of the parameters `ps` (or the single parameter `p`)
-according to optimizer `opt` and the gradients `gs` (the gradient `g`).
+according to optimizer `opt::AbstractOptimiser` and the gradients `gs` (the gradient `g`).
 
 As a result, the parameters are mutated and the optimizer's internal state may change.
 The gradient could be mutated as well.
+
+!!! note
+    This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.14.
+    The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain.
 """
-function update!(opt::AbstractOptimiser, x, x̄)
-  x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's
-                                        # output are not mutable, see #1510
+function update!(opt::AbstractOptimiser, x::AbstractArray, x̄)
+  x̄r = copyto!(similar(x̄), x̄)  # Flux.Optimise assumes it can mutate the gradient. This is not
+                               # safe due to aliasing, nor guaranteed to be possible, e.g. Fill.
   x .-= apply!(opt, x, x̄r)
 end
 
@@ -88,6 +95,10 @@ batchmemaybe(x::Tuple) = x
 Uses a `loss` function and training `data` to improve the
 model's parameters according to a particular optimisation rule `opt`.
 
+!!! note
+    This method with implicit `Params` will be removed from Flux 0.14.
+    It should be replaced with the explicit method `train!(loss, model, data, opt)`.
+
 For each `d in data`, first the gradient of the `loss` is computed like this:
 ```
 gradient(() -> loss(d...), pars)  # if d isa Tuple
````
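To show how the single `Flux.update!` now covers both styles, a small usage sketch (an assumption for illustration, not code from the commit); note the explicit call takes `gradient(...)[1]`, exactly the mistake the `Chain`-specific warning in `src/deprecations.jl` guards against:

```julia
using Flux

m = Dense(1 => 1)
x, y = rand(Float32, 1, 8), rand(Float32, 1, 8)

# Implicit style (scheduled for removal in Flux 0.14, per the docstring note above):
ps = Flux.params(m)
gs = gradient(() -> Flux.Losses.mse(m(x), y), ps)
Flux.update!(Descent(0.1), ps, gs)

# Explicit style (the Optimisers.jl method, which remains):
state = Flux.setup(Descent(0.1), m)
grad = gradient(model -> Flux.Losses.mse(model(x), y), m)[1]   # the model's gradient alone
Flux.update!(state, m, grad)
```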
