-using ProgressLogging: @progress, @withprogress, @logprogress
-import Zygote: Params, gradient, withgradient
-
-
"""
    update!(opt, p, g)
    update!(opt, ps::Params, gs)
@@ -12,18 +8,23 @@ according to optimizer `opt` and the gradients `gs` (the gradient `g`).

As a result, the parameters are mutated and the optimizer's internal state may change.
The gradient could be mutated as well.
"""
-function update!(opt::AbstractOptimiser, x, x̄)
+function Optimisers.update!(opt::AbstractOptimiser, x, x̄)
  x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's
                                        # output are not mutable, see #1510
  x .-= apply!(opt, x, x̄r)
+
+  return opt, x
end

-function update!(opt::AbstractOptimiser, xs::Params, gs)
+function Optimisers.update!(opt::AbstractOptimiser, xs::Params, gs)
  for x in xs
    isnothing(gs[x]) && continue
    update!(opt, x, gs[x])
  end
+
+  return opt, xs
end
+Optimisers.update(opt::AbstractOptimiser, xs::Params, gs) = update!(opt, xs, gs)
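For orientation, a minimal usage sketch of the implicit-parameter path above (not part of this commit; it assumes `Flux` is loaded and uses `Flux.Losses.mse` purely as an example loss):

using Flux                                    # assumed: brings Dense, params, gradient, Descent
import Optimisers                             # this diff extends Optimisers.update!
m   = Dense(2 => 1)
ps  = Flux.params(m)                          # implicit Params collection
opt = Descent(0.1)                            # old-style AbstractOptimiser
x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)
gs  = gradient(() -> Flux.Losses.mse(m(x), y), ps)   # Zygote.Grads
opt, ps = Optimisers.update!(opt, ps, gs)     # per this diff, update! now also returns (opt, ps)
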
# Callback niceties
call(f, xs...) = f(xs...)

batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x

+_build_loss(::AD.AbstractBackend, loss, data) = function _loss(m)
+  return loss(m, data...)
+end
+_build_loss(::ZygoteImplicitBackend, loss, data) = function _loss()
+  return loss(data...)
+end
+_gradient_only(x::Zygote.Grads) = x
+_gradient_only(x::NTuple{1}) = x[1]
+_gradient_only(x) = error("Expected gradient w.r.t. single argument (or Zygote.Grads) but got $x")
+
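To make the contract of these helpers concrete, a hedged sketch (not part of the commit, written as if evaluated next to the definitions above; `AD.ZygoteBackend()` is used only to select the explicit-backend method):

using Flux                                             # assumed: Dense, gradient
import AbstractDifferentiation as AD                   # assumed: the `AD` this file refers to
loss(m, x, y) = sum(abs2, m(x) .- y)
model = Dense(2 => 1)
x, y  = rand(Float32, 2, 8), rand(Float32, 1, 8)
_loss = _build_loss(AD.ZygoteBackend(), loss, (x, y))  # explicit path: _loss(m) == loss(m, x, y)
grads = gradient(_loss, model)                         # Zygote explicit mode returns a 1-tuple
g     = _gradient_only(grads)                          # unwraps NTuple{1}; Zygote.Grads passes through unchanged
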
"""
    train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])
@@ -122,19 +133,18 @@ The callback can call [`Flux.stop`](@ref) to interrupt the training loop.

Multiple callbacks can be passed to `cb` as array.
"""
-function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
+function train!(loss, ad::AD.AbstractBackend, model, data, optstate; cb = () -> ())
  cb = runall(cb)
  itrsz = Base.IteratorSize(typeof(data))
  n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
  @withprogress for (i, d) in enumerate(data)
    try
-      l, gs = withgradient(ps) do
-        loss(batchmemaybe(d)...)
-      end
+      _loss = _build_loss(ad, loss, batchmemaybe(d))
+      l, gs = AD.value_and_gradient(ad, _loss, model)
      if !isfinite(l)
        throw(DomainError("Loss is $l on data item $i, stopping training"))
      end
-      update!(opt, ps, gs)
+      optstate, model = update(optstate, model, _gradient_only(gs))
      cb()
    catch ex
      if ex isa StopException
@@ -147,7 +157,11 @@ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
    end
    @logprogress iszero(n) ? nothing : i / n
  end
+
+  return optstate, model
end
+train!(loss, model, data, optstate; kwargs...) =
+  train!(loss, ZygoteImplicitBackend(), model, data, optstate; kwargs...)
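A hedged end-to-end sketch of the two entry points this hunk creates (choosing `AD.ZygoteBackend()` for the explicit path and `Optimisers.setup`/`Optimisers.Descent` for `optstate` are assumptions about how the rest of this PR is meant to be used, not something this diff confirms):

using Flux                                       # assumed available
import Optimisers
import AbstractDifferentiation as AD             # assumed: the `AD` this file refers to
loss(m, x, y) = Flux.Losses.mse(m(x), y)         # explicit path: the model is the first argument
model    = Dense(2 => 1)
optstate = Optimisers.setup(Optimisers.Descent(0.1), model)
data     = [(rand(Float32, 2, 8), rand(Float32, 1, 8)) for _ in 1:10]
optstate, model = train!(loss, AD.ZygoteBackend(), model, data, optstate)
# The shorter method above defaults to ZygoteImplicitBackend(), where the loss instead closes
# over the model and takes only the batch, mirroring the old train!(loss, ps, data, opt).
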

"""
    @epochs N body