Commit 7aa74e0

wrap more rules, test & doc
1 parent cb927a8 commit 7aa74e0

2 files changed: +55 -10 lines

src/train/Train.jl

Lines changed: 48 additions & 4 deletions
@@ -43,16 +43,60 @@ function Base.show(io::IO, opt::FluxState)
 end
 end
 
+_DESCENT_EXAMPLE = """# Implicit-style example
+This usage matches Flux ≤ v0.13:
+```
+opt = Flux.Descent(0.3)
+
+ps = Flux.params(model) # returns a Zygote.Params object
+
+gs = gradient(ps) do # gradient takes a zero-argument anonymous function
+  loss3(model, x, y) # ... which depends on the global model
+end # ... and returns a Zygote.Grads object
+
+Flux.update!(opt, ps, gs)
+```
+New on Flux v0.14 is a method `train!(loss, ps, opt)` which performs one step,
+rather than iterating over `data`. This is equivalent to `gradient` and `update!` above:
+```
+Flux.train!(ps, opt) do
+  loss3(model, x, y)
+end
+```
+
+# Explicit-style example
+
+This no longer uses `Flux.params`, but instead the model itself:
+```
+opt = Flux.Descent(0.3) # the same FluxState object
+
+Flux.train!(model, opt) do m # now explicitly depends on the model
+  loss3(m, x, y)
+end
+```
+"""
 for opt in [
   :Descent, :Adam, :Momentum, :Nesterov, :RMSProp,
   :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :AdamW, :RAdam, :OAdam, :AdaBelief,
-  # :InvDecay, :ExpDecay, :WeightDecay, :stop, :skip, :Optimiser,
-  # :ClipValue, :ClipNorm,
-  # TODO check that parameters line up nicely old-vs-new, and include the remaining rules
+  # :InvDecay, :ExpDecay, :WeightDecay, :Optimiser,
+  :ClipGrad, :ClipNorm,
+  # TODO sort out the remaining rules
 ]
-  @eval $opt(parameters...; kw...) = FluxState(Optimisers.$opt(parameters...; kw...), missing)
+  @eval begin
+    $opt(parameters...; kw...) = FluxState(Optimisers.$opt(parameters...; kw...), missing)
+    str = string(""" Flux.$($opt)(args...)
+
+Returns `FluxState` wrapper around the following rule definition from Optimisers.jl,
+allowing its use with `Flux.train!` (in the same manner as `Flux.AbstractOptimiser` objects on Flux ≤ v0.13).
+Accepts the same arguments, with the same defaults, as the underlying rule:
+
+""", @doc(Optimisers.$opt), $opt == Descent ? _DESCENT_EXAMPLE : "")
+    @doc str $opt
+  end
 end
 
+@deprecate ClipValue ClipGrad
+
 
 ### Two styles of gradient, and their `train!` functions
 

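For orientation, the `@eval begin ... end` loop above defines one wrapper constructor per rule and attaches a generated docstring to it. Below is a rough sketch of what it produces for `:Descent` (docstring text abbreviated; an illustration of the pattern, not the literal macro expansion):

```julia
# Sketch of the code generated for :Descent by the loop above (abbreviated, not literal):
Descent(parameters...; kw...) = FluxState(Optimisers.Descent(parameters...; kw...), missing)

str = string(""" Flux.Descent(args...)

Returns `FluxState` wrapper around the following rule definition from Optimisers.jl, ...
""", @doc(Optimisers.Descent), _DESCENT_EXAMPLE)  # only Descent appends the long example

@doc str Descent  # attach the assembled text as the docstring of the new wrapper
```

So `?Flux.Descent` at the REPL should show the Optimisers.jl rule documentation followed by the implicit- and explicit-style training examples.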
test/train.jl

Lines changed: 7 additions & 6 deletions
@@ -10,10 +10,9 @@ using Random
 Random.seed!(84)
 w = randn(10, 10)
 w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
-@testset for opt in [Descent(0.1), Adam()]
-  # [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(),
-  #  NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(),
-  #  Nesterov(), RMSProp(), Momentum()]
+@testset for opt in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(),
+                     NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(),
+                     Nesterov(), RMSProp(), Momentum()]
   w′ = copy(w2)
   b = zeros(10)
   loss(x) = Flux.Losses.mse(w*x, w′*x .+ b)
@@ -27,7 +26,9 @@ end
 Random.seed!(84)
 w = randn(10, 10)
 w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
-@testset for opt in [Descent(0.1), Adam()]
+@testset for opt in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(),
+                     NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(),
+                     Nesterov(), RMSProp(), Momentum()]
   @test opt isa FluxState
   @test opt.state isa Missing
 
@@ -41,7 +42,7 @@ end
 end
 
 # Test 3-arg `train!` method:
-@testset for opt in [Descent(0.1), Adam()]
+@testset for opt in [Descent(0.1), Adam(), AdamW()]
   @test opt isa FluxState
   @test opt.state isa Missing

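The widened `@testset` lists above now run every wrapped rule through the same training code path. As a usage sketch matching the explicit-style example added to the `Train.jl` docstring (here `model`, `x`, `y`, and `loss3` are placeholders, not names defined by this commit):

```julia
using Flux

opt = Flux.AdamW()             # a FluxState wrapper; opt.state is `missing` until first use
Flux.train!(model, opt) do m   # explicit style: the closure receives the model itself
  loss3(m, x, y)               # placeholder loss of the model and the data
end
```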