@@ -111,25 +111,32 @@ function train!(loss, model, data, opt; cb = nothing)
   @withprogress for (i,d) in enumerate(data)
     d_splat = d isa Tuple ? d : (d,)

-    if model isa Enzyme.Duplicated
-      _make_zero!(model.dval)
-      _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...)
+    l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)

-      if !isfinite(l)
-        throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
-      end
-      opt, model2 = Optimisers.update!(opt, model.val, model.dval)
-      model = Enzyme.Duplicated(model2, model.dval)
-    else
-      l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
+    if !isfinite(l)
+      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
+    end

-      if !isfinite(l)
-        throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
-      end
+    opt, model = Optimisers.update!(opt, model, gs[1])

-      opt, model = Optimisers.update!(opt, model, gs[1])
+    @logprogress Base.haslength(data) ? i/length(data) : nothing
+  end
+end
+function train!(loss, model::Enzyme.Duplicated, data, opt; cb = nothing)
+  isnothing(cb) || error("""train! does not support callback functions.
+                            For more control use a loop with `gradient` and `update!`.""")
+  @withprogress for (i,d) in enumerate(data)
+    d_splat = d isa Tuple ? d : (d,)
+
+    _make_zero!(model.dval)
+    _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...)

+    if !isfinite(l)
+      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
     end
+    opt, model2 = Optimisers.update!(opt, model.val, model.dval)
+    model = Enzyme.Duplicated(model2, model.dval)
+
     @logprogress Base.haslength(data) ? i/length(data) : nothing
   end
 end
@@ -138,6 +145,9 @@
 function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
   train!(loss, model, data, _rule_to_state(model, rule); cb)
 end
+function train!(loss, model::Enzyme.Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing)
+  train!(loss, model, data, _rule_to_state(model, rule); cb)
+end

 function _rule_to_state(model, rule::Optimisers.AbstractRule)
   state = setup(rule, model)
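For context, here is a minimal sketch of how the two `train!` methods in this diff would be called. The model, data, loss, and optimiser below are illustrative placeholders (not part of the diff), and the `Enzyme.Duplicated` wrapper is built with `Enzyme.make_zero` to hold the gradient shadow; exact convenience constructors may differ across Flux and Enzyme versions.

```julia
using Flux, Optimisers, Enzyme

# Illustrative model, data and loss (names and sizes are hypothetical).
model = Chain(Dense(2 => 8, relu), Dense(8 => 1))
data  = [(randn(Float32, 2, 16), randn(Float32, 1, 16)) for _ in 1:10]
loss(m, x, y) = Flux.mse(m(x), y)

# Optimiser state set up against the plain model.
opt_state = Optimisers.setup(Optimisers.Adam(1f-3), model)

# Default method: gradients taken with Zygote.withgradient.
Flux.train!(loss, model, data, opt_state)

# Enzyme method added in this diff: wrapping the model in Duplicated
# dispatches train! to the Enzyme.autodiff path, with the shadow (dval)
# accumulating the gradient on each step.
dup = Enzyme.Duplicated(model, Enzyme.make_zero(model))
Flux.train!(loss, dup, data, opt_state)
```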