@@ -7,7 +7,7 @@ using ..Flux: Flux # used only in docstring
 import ..Flux.Optimise: train!, update!  # during 0.13, we add methods to the old functions
 import Enzyme
 
-export setup, train!, train_enzyme!
+export setup, train!
 
 using ProgressLogging: @progress, @withprogress, @logprogress
 using Zygote: Zygote, Params
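For context (not part of the diff): with `train_enzyme!` no longer exported, both AD backends are reached through the same exported `train!`. A minimal usage sketch, where the model, loss, and data are illustrative placeholders:

```julia
using Flux, Enzyme
import Optimisers

model = Chain(Dense(2 => 8, relu), Dense(8 => 1))
loss(m, x, y) = Flux.mse(m(x), y)
data = [(randn(Float32, 2, 16), randn(Float32, 1, 16)) for _ in 1:3]
opt_state = Flux.setup(Optimisers.Adam(1f-3), model)

# Zygote path: pass the model itself.
Flux.train!(loss, model, data, opt_state)

# Enzyme path: wrap the model together with a shadow copy that will hold gradients.
dup = Enzyme.Duplicated(model, deepcopy(model))
Flux.train!(loss, dup, data, opt_state)
```

The optimiser state is built from the plain model in both cases; the Enzyme branch added further down updates `dup.val` against that same state.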
@@ -53,6 +53,12 @@ function setup(rule::Optimisers.AbstractRule, model)
   state
 end
 
+_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
+_make_zero_internal!(x) = x
+_make_zero!(model) = fmap(_make_zero_internal!, model)
+
+_applyloss(loss, model, d...) = loss(model, d...)
+
 """
     train!(loss, model, data, opt_state)
 
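A standalone illustration (hypothetical names, not the diff's private helpers) of what the `fmap`-based zeroing above does: every array leaf of a nested model is reset in place, which is how the shadow gradient buffer is cleared before each Enzyme call.

```julia
using Flux
using Functors: fmap

zero_leaf!(x::AbstractArray) = fill!(x, 0)  # zero array leaves in place
zero_leaf!(x) = x                           # leave activations, integers, etc. untouched
zero_all!(model) = fmap(zero_leaf!, model)

shadow = Chain(Dense(2 => 3, relu), Dense(3 => 1))  # stand-in for a Duplicated's dval
zero_all!(shadow)
all(iszero, shadow[1].weight)  # true: the gradient accumulator is reset
```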
@@ -61,6 +67,9 @@ according to a particular optimisation rule encoded in `opt_state`.
 Iterates through `data` once, evaluating for each `d in data` either
 `loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`.
 
+If `model` is an `Enzyme.Duplicated`, gradients will be computed with Enzyme;
+otherwise they will be computed with Zygote.
+
 For example, with these definitions...
 ```
 data = [(x1, y1), (x2, y2), (x3, y3)]
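To make the Enzyme/Zygote dispatch concrete, a single-step sketch in the style of the new Enzyme branch (shown in the next hunk); `apply_loss`, the toy model, and the data are illustrative stand-ins. `ReverseWithPrimal` runs reverse-mode AD and also returns the primal loss, while the gradient is accumulated into the shadow `dup.dval`:

```julia
using Flux, Enzyme
using Functors: fmap
import Optimisers

apply_loss(loss, model, d...) = loss(model, d...)   # same role as the PR's _applyloss

model = Dense(2 => 1)
dup = Enzyme.Duplicated(model, deepcopy(model))      # value + shadow gradient buffer
opt_state = Flux.setup(Optimisers.Descent(0.1), model)

loss(m, x, y) = Flux.mse(m(x), y)
x, y = randn(Float32, 2, 8), randn(Float32, 1, 8)

fmap(g -> g isa AbstractArray ? fill!(g, 0) : g, dup.dval)  # reset the shadow
_, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, apply_loss, Enzyme.Active,
                       Enzyme.Const(loss), dup, Enzyme.Const(x), Enzyme.Const(y))
Optimisers.update!(opt_state, dup.val, dup.dval)            # apply the accumulated gradient
```

Zygote users need none of this; passing the bare `model` keeps the existing `withgradient` path.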
@@ -101,60 +110,30 @@ function train!(loss, model, data, opt; cb = nothing)
                             For more control use a loop with `gradient` and `update!`.""")
   @withprogress for (i,d) in enumerate(data)
     d_splat = d isa Tuple ? d : (d,)
-    l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
-    if !isfinite(l)
-      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
-    end
-    opt, model = Optimisers.update!(opt, model, gs[1])
-    @logprogress Base.haslength(data) ? i/length(data) : nothing
-  end
-end
+
+    if model isa Enzyme.Duplicated
+      _make_zero!(model.dval)
+      _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...)
 
-_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
-_make_zero_internal!(x) = x
-_make_zero!(model) = fmap(_make_zero_internal!, model)
+      if !isfinite(l)
+        throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
+      end
+      opt, model2 = Optimisers.update!(opt, model.val, model.dval)
+      model = Enzyme.Duplicated(model2, model.dval)
+    else
+      l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
 
-_applyloss(loss, model, d...) = loss(model, d...)
+      if !isfinite(l)
+        throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
+      end
 
-"""
-    train_enzyme!(loss, model_and_shadow, data, opt_state)
-
-Like [`train!](@ref), but gradient computed in place using [Enzyme](github.com/EnzymeAD/Enzyme.jl)
-"""
-function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state::T) where T<:Optimisers.AbstractRule
-  @withprogress for (i,d) in enumerate(data)
-    d_splat = d isa Tuple ? d : (d,)
-    _make_zero!(model_and_shadow.dval)
-    _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model_and_shadow, map(Enzyme.Const, d_splat)...)
-
-    if !isfinite(l)
-      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
-    end
-    opt_state, model = Optimisers.update!(opt_state, model_and_shadow.val, model_and_shadow.dval)
-    model_and_shadow = Enzyme.Duplicated(model, model_and_shadow.dval)
-    @logprogress Base.haslength(data) ? i/length(data) : nothing
-  end
-end
-
-# Required per method ambiguity with
-#   train!(loss, model, data, opt::Flux.Optimise.AbstractOptimiser; cb)
-#   @ Flux ~/work/Flux.jl/Flux.jl/src/deprecations.jl:110
-function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state::Flux.Optimise.AbstractOptimiser)
-  @withprogress for (i,d) in enumerate(data)
-    d_splat = d isa Tuple ? d : (d,)
-    _make_zero!(model_and_shadow.dval)
-    _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model_and_shadow, map(Enzyme.Const, d_splat)...)
-
-    if !isfinite(l)
-      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
-    end
-    opt_state, model = Optimisers.update!(opt_state, model_and_shadow.val, model_and_shadow.dval)
-    model_and_shadow = Enzyme.Duplicated(model, model_and_shadow.dval)
+      opt, model = Optimisers.update!(opt, model, gs[1])
+    end
+
     @logprogress Base.haslength(data) ? i/length(data) : nothing
   end
 end
 
-
 # This method let you use Optimisers.Descent() without setup, when there is no state
 function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
   train!(loss, model, data, _rule_to_state(model, rule); cb)
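The trailing context above keeps the rule-without-setup convenience. A small sketch of that path (illustrative model and data), where `train!` builds the optimiser state internally via `_rule_to_state`:

```julia
using Flux
import Optimisers

model = Dense(1 => 1)
loss(m, x, y) = Flux.mse(m(x), y)
data = [(rand(Float32, 1, 4), rand(Float32, 1, 4))]

# Stateless rules like Descent can skip the explicit Flux.setup call.
Flux.train!(loss, model, data, Optimisers.Descent(0.1))
```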