@@ -1,7 +1,3 @@
- using ProgressLogging: @progress, @withprogress, @logprogress
- import Zygote: Params, gradient
-
-
"""
    update!(opt, p, g)
    update!(opt, ps::Params, gs)
@@ -12,18 +8,23 @@ according to optimizer `opt` and the gradients `gs` (the gradient `g`).
As a result, the parameters are mutated and the optimizer's internal state may change.
The gradient could be mutated as well.
"""
- function update!(opt::AbstractOptimiser, x, x̄)
+ function Optimisers.update!(opt::AbstractOptimiser, x, x̄)
  x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's
                                        # output are not mutable, see #1510
  x .-= apply!(opt, x, x̄r)
+
+ return opt, x
end

- function update!(opt::AbstractOptimiser, xs::Params, gs)
+ function Optimisers.update!(opt::AbstractOptimiser, xs::Params, gs)
  for x in xs
    isnothing(gs[x]) && continue
    update!(opt, x, gs[x])
  end
+
+ return opt, xs
end
+ Optimisers.update(opt::AbstractOptimiser, xs::Params, gs) = update!(opt, xs, gs)
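Note (not part of the diff): a minimal usage sketch of the contract introduced above, assuming Flux and Optimisers are loaded; Descent is used purely as an example AbstractOptimiser. The parameters are still mutated in place, but the optimiser and the parameters are now also returned, so the state can be threaded through explicitly.

# Illustrative sketch only, not part of this commit.
opt = Descent(0.1)              # any AbstractOptimiser
x = rand(Float32, 3)
g = ones(Float32, 3)            # a gradient with the same shape as x
opt, x = Optimisers.update!(opt, x, g)   # mutates x in place and returns (opt, x)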

# Callback niceties
call(f, xs...) = f(xs...)
@@ -82,6 +83,16 @@
batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x

+ _build_loss(::AD.AbstractBackend, loss, data) = function _loss(m)
+   return loss(m, data...)
+ end
+ _build_loss(::ZygoteImplicitBackend, loss, data) = function _loss()
+   return loss(data...)
+ end
+ _gradient_only(x::Zygote.Grads) = x
+ _gradient_only(x::NTuple{1}) = x[1]
+ _gradient_only(x) = error("Expected gradient w.r.t. single argument (or Zygote.Grads) but got $x")
+
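Note (not part of the diff): an illustrative sketch of what these helpers produce; `ad`, `model`, `x` and `y` are placeholder names. An explicit AD backend yields a closure over the data that takes the model, while the implicit Zygote backend yields a zero-argument closure whose gradients arrive as Zygote.Grads.

# Illustrative sketch only, not part of this commit.
loss(m, x, y) = sum(abs2, m(x) .- y)
f = _build_loss(ad, loss, (x, y))                 # behaves like m -> loss(m, x, y)
gs = _gradient_only(AD.gradient(ad, f, model))    # AD.gradient returns a 1-tuple; unwrap it

loss_implicit(x, y) = sum(abs2, model(x) .- y)    # closes over `model`; params are implicit
f0 = _build_loss(ZygoteImplicitBackend(), loss_implicit, (x, y))  # behaves like () -> loss_implicit(x, y)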
"""
    train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
@@ -120,16 +131,15 @@ The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
Multiple callbacks can be passed to `cb` as array.
"""
- function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
+ function train!(loss, ad::AD.AbstractBackend, model, data, optstate; cb = () -> ())
  cb = runall(cb)
  itrsz = Base.IteratorSize(typeof(data))
  n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
  @withprogress for (i, d) in enumerate(data)
    try
-     gs = gradient(ps) do
-       loss(batchmemaybe(d)...)
-     end
-     update!(opt, ps, gs)
+     _loss = _build_loss(ad, loss, batchmemaybe(d))
+     gs = _gradient_only(AD.gradient(ad, _loss, model))
+     optstate, model = update(optstate, model, gs)
      cb()
    catch ex
      if ex isa StopException
@@ -142,7 +152,11 @@ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
    end
    @logprogress iszero(n) ? nothing : i / n
  end
+
+ return optstate, model
end
+ train!(loss, model, data, optstate; kwargs...) =
+   train!(loss, ZygoteImplicitBackend(), model, data, optstate; kwargs...)
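Note (not part of the diff): a hedged usage sketch of the two train! entry points defined above. The model, data and optimiser names are illustrative, `ad` stands for whatever explicit AD.AbstractBackend the rest of this PR provides, and the Optimisers.setup call is an assumption about how `optstate` would be built.

# Illustrative sketch only, not part of this commit.
model = Dense(2, 1)
data  = [(rand(Float32, 2, 8), rand(Float32, 1, 8))]

# Explicit style: the loss takes the model first; optstate is Optimisers.jl-style state.
optstate = Optimisers.setup(Optimisers.Descent(0.1), model)
optstate, model = train!((m, x, y) -> Flux.Losses.mse(m(x), y), ad, model, data, optstate)

# Implicit Zygote fallback: the loss takes only the data and closes over `model`;
# parameters are passed as Params and the optimiser is a classic AbstractOptimiser.
ps  = Flux.params(model)
opt = Descent(0.1)
opt, ps = train!((x, y) -> Flux.Losses.mse(model(x), y), ps, data, opt)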

"""
    @epochs N body