@@ -43,16 +43,60 @@ function Base.show(io::IO, opt::FluxState)
end
end
+ _DESCENT_EXAMPLE = """# Implicit-style example
+ This usage matches Flux ≤ v0.13:
+ ```
+ opt = Flux.Descent(0.3)
+
+ ps = Flux.params(model)   # returns a Zygote.Params object
+
+ gs = gradient(ps) do      # gradient takes a zero-argument anonymous function
+     loss3(model, x, y)    # ... which depends on the global model
+ end                       # ... and returns a Zygote.Grads object
+
+ Flux.update!(opt, ps, gs)
+ ```
+ New in Flux v0.14 is a method `train!(loss, ps, opt)` which performs one step,
+ rather than iterating over `data`. This is equivalent to the `gradient` and `update!` steps above:
+ ```
+ Flux.train!(ps, opt) do
+     loss3(model, x, y)
+ end
+ ```
+
+ # Explicit-style example
+
+ This no longer uses `Flux.params`; instead, `train!` takes the model itself:
+ ```
+ opt = Flux.Descent(0.3)   # the same FluxState object
+
+ Flux.train!(model, opt) do m   # now explicitly depends on the model
+     loss3(m, x, y)
+ end
+ ```
+ """
for opt in [
    :Descent, :Adam, :Momentum, :Nesterov, :RMSProp,
    :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :AdamW, :RAdam, :OAdam, :AdaBelief,
-     # :InvDecay, :ExpDecay, :WeightDecay, :stop, :skip, :Optimiser,
-     # :ClipValue, :ClipNorm,
-     # TODO check that parameters line up nicely old-vs-new, and include the remaining rules
+     # :InvDecay, :ExpDecay, :WeightDecay, :Optimiser,
+     :ClipGrad, :ClipNorm,
+     # TODO sort out the remaining rules
]
-     @eval $opt(parameters...; kw...) = FluxState(Optimisers.$opt(parameters...; kw...), missing)
+     @eval begin
+         $opt(parameters...; kw...) = FluxState(Optimisers.$opt(parameters...; kw...), missing)
+         str = string("""    Flux.$($opt)(args...)
+
+ Returns a `FluxState` wrapper around the following rule definition from Optimisers.jl,
+ allowing its use with `Flux.train!` (in the same manner as `Flux.AbstractOptimiser` objects in Flux ≤ v0.13).
+ Accepts the same arguments, with the same defaults, as the underlying rule:
+
+         """, @doc(Optimisers.$opt), $opt == Descent ? _DESCENT_EXAMPLE : "")
+         @doc str $opt
+     end
end
+ @deprecate ClipValue ClipGrad
+
# ## Two styles of gradient, and their `train!` functions
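
For readers unpicking the metaprogramming in the loop above: with `opt == :Descent`, the `@eval begin ... end` block expands to roughly the following. This is a sketch with the interpolations written out by hand, and it assumes it runs inside the module where `FluxState`, `Optimisers`, and `_DESCENT_EXAMPLE` are already defined:

```
# Hand-expanded form of the @eval block for opt = :Descent (sketch, interpolations applied)
Descent(parameters...; kw...) = FluxState(Optimisers.Descent(parameters...; kw...), missing)

str = string("""    Flux.Descent(args...)

Returns a `FluxState` wrapper around the following rule definition from Optimisers.jl,
allowing its use with `Flux.train!` (in the same manner as `Flux.AbstractOptimiser` objects in Flux ≤ v0.13).
Accepts the same arguments, with the same defaults, as the underlying rule:

""", @doc(Optimisers.Descent), _DESCENT_EXAMPLE)  # the ternary picks _DESCENT_EXAMPLE only for Descent

@doc str Descent  # attach the assembled docstring to the new wrapper
```

Defining the wrappers in a loop keeps each Flux optimiser name a thin shim over the corresponding Optimisers.jl rule while re-attaching that rule's documentation (plus the implicit/explicit `train!` examples, for `Descent`), so after this loop a call such as `Flux.Adam(0.01)` returns a `FluxState` that `Flux.train!` accepts in either style.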