@@ -4,7 +4,7 @@ using LinearAlgebra
using Optimisers: Optimisers
using Functors: fmap

- export train!, update!, adjust!, FluxState, @epochs,
+ export train!, update!, adjust!, FluxState,
  Descent, Adam, Momentum, Nesterov, RMSProp,
  AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief # ,
  # InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
@@ -15,7 +15,7 @@ export train!, update!, adjust!, FluxState, @epochs,

"""
    FluxState(rule, state=missing)
-
+
This is an interface between the all-mutable world Flux.jl likes,
and the could-be-immutable world that Optimisers.jl inhabits.

@@ -56,34 +56,14 @@

# ## Two styles of gradient, and their `train!` functions

- using ProgressLogging: @progress, @withprogress, @logprogress
+ using ProgressLogging: @progress, @withprogress, @logprogress  # TODO add progress logging again
using Zygote: Zygote, Params

- include("explicit_train.jl.jl")  # new!
- include("implicit_train.jl.jl")  # Params etc, Zygote only
+ include("explicit_train.jl")  # new!
+ include("implicit_train.jl")  # Params etc, Zygote only

explicit_withgradient(f, args...) = Zygote.withgradient(f, args...)  # can overload this to use e.g. Yota / Diffractor

- # using Requires # Flux doesn't use this right now
- # @init @require Diffractor="9f5e2b26-1114-432f-b630-d3fe2085c51c" begin
- #   @eval function explicit_withgradient(f, args...)
- #     y, back = Diffractor.∂⃖¹(f, args...)
- #     _, grads... = back(Zygote.sensitivity(y))
- #     return (; value = y, gradient = grads)
- #   end
- # end
-
- #=
-
- using Diffractor
- function Flux.Train.explicit_withgradient(f, args...)
-   y, back = Diffractor.∂⃖¹(f, args...)
-   _, grads... = back(one(y))
-   return (; value = y, gradient = grads)
- end
-
- =#
-
# ## Misc. related utilities

"""
@@ -107,94 +87,4 @@ function adjust!(opt::FluxState, eta::Real)
  return opt
end

- """
-     @epochs N body
-
- Run `body` expression `N` times. Mainly useful for quickly doing
- multiple epochs of training in a REPL.
-
- Functionally equivalent to this loop:
- ```
- for _ in 1:N
-   body
- end
- ```
- ... but adds progress logging and `@info` messages,
- and returns the result of the last iteration.
-
- # Examples
- ```jldoctest
- julia> Flux.@epochs 2 println("hello")
- [ Info: Epoch 1
- hello
- [ Info: Epoch 2
- hello
- ```
- """
- macro epochs(n, ex)
-   @gensym val
-   body = :(for i in 1:$(esc(n))
-     @info "Epoch $i"
-     $(esc(val)) = $(esc(ex))
-   end)
-   loop = Expr(:macrocall, Symbol("@progress"), __source__, body)
-   Expr(:block, :($(esc(val)) = nothing), loop, :($(esc(val))))
-   # TODO make this actualy return the value? Names aren't right.
-   #
-   # $loop
-   # # @progress for i in 1:$(esc(n))
-   # #   @info "Epoch $i"
-   # #   $(esc(val)) = $(esc(ex))
-   # # end
-   # $val # DOESN"T WORK! Expr(:macrocall, ...) ?
-   # end
- end
-
- end
-
-
- #=
-
- using Flux, Random
- data = [(rand(3,2).*[i,1,20/i], [i i]) for i in 1:50] |> shuffle!;
-
- # This exact code works on Flux@0.13. There, train! returns nothing:
- model2 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
- opt2 = Flux.Adam()
- Flux.train!(Flux.params(model2), data, opt2) do x, y
-   Flux.mse(model2(x), y)
- end
- opt2 # contains an IdDict
-
- # This is the new "explicit" method of Train
- model1 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
- opt1 = Flux.Adam()
- Flux.train!(model1, data, opt1) do m, x, y
-   Flux.mse(m(x), y)
- end |> sum
- opt1 # contains state tree
-
- # This is new 3-arg train!, one step not an iteration over data:
- x1, y1 = data[1]
- Flux.train!(model1, opt1) do m
-   Flux.mse(m(x1), y1)
- end
-
-
-
-
-
- julia> using ProgressLogging
- julia> @macroexpand1 @loop N body
- begin
-   x = nothing
-   @progress for i in 1:N
-     @info "step $i"
-     x = body
-   end
-   x
- end
-
-
-
- =#
+ end # module
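
For reference, the two `train!` styles this module supports, condensed from the commented-out demo deleted above; the toy model and data are illustrative, not part of the diff:

```julia
using Flux, Random

# Toy dataset of (input, target) pairs, as in the deleted demo:
data = [(rand(3, 2) .* [i, 1, 20/i], [i i]) for i in 1:50] |> shuffle!

# Old "implicit" style (works on Flux@0.13): Zygote Params, the loss closes
# over the model, and the optimiser state lives in an IdDict.
model2 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
opt2 = Flux.Adam()
Flux.train!(Flux.params(model2), data, opt2) do x, y
    Flux.mse(model2(x), y)
end

# New "explicit" style: the model is an argument of the loss, and the
# optimiser carries an Optimisers.jl-style state tree.
model1 = Chain(Dense(3 => 7, relu), Dense(7 => 1))
opt1 = Flux.Adam()
Flux.train!(model1, data, opt1) do m, x, y
    Flux.mse(m(x), y)
end

# adjust! (defined above) changes the learning rate held by the FluxState:
Flux.adjust!(opt1, 1e-4)
```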
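The `explicit_withgradient` definition is the intended swap point for other AD backends, as its comment says, and the deleted `#=` block shows the pattern. A sketch, assuming Diffractor still exposes the `∂⃖¹` entry point used there:

```julia
using Diffractor

# Route Flux.Train's gradient computation through Diffractor instead of Zygote:
function Flux.Train.explicit_withgradient(f, args...)
    y, back = Diffractor.∂⃖¹(f, args...)  # forward pass: value plus pullback
    _, grads... = back(one(y))            # seed the pullback; drop the gradient w.r.t. f
    return (; value = y, gradient = grads)
end
```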