Commit 0cc6190

Fixup
1 parent c01ab6f commit 0cc6190

File tree

2 files changed: +27 additions, -47 deletions


src/losses/utils.jl

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,5 @@
+import Enzyme
+
 """
     xlogx(x)
 
@@ -36,5 +38,4 @@ end
 _check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1
 
 ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)
-import Enzyme
 Enzyme.EnzymeRules.inactive(::typeof(_check_sizes), args...) = true
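This hunk moves `import Enzyme` to the top of the file and keeps `_check_sizes` marked inactive, so Enzyme treats the size check as constant and does not differentiate through it. A minimal standalone sketch of how `EnzymeRules.inactive` behaves; the helper `_validate` and the function `f` are illustrative assumptions, not part of Flux:

```julia
import Enzyme

# Hypothetical validation helper; it returns nothing, so there is nothing to differentiate.
_validate(x) = (isempty(x) && error("empty input"); nothing)

# Marking it inactive tells Enzyme to treat calls to it as constant,
# mirroring the `inactive(::typeof(_check_sizes), ...)` rule above.
Enzyme.EnzymeRules.inactive(::typeof(_validate), args...) = true

f(x) = (_validate(x); sum(abs2, x))

x  = [1.0, 2.0, 3.0]
dx = zero(x)
Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, Enzyme.Duplicated(x, dx))
# dx now holds 2 .* x; the gradient skips `_validate` entirely
```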

src/train.jl

Lines changed: 25 additions & 46 deletions
@@ -7,7 +7,7 @@ using ..Flux: Flux # used only in docstring
 import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions
 import Enzyme
 
-export setup, train!, train_enzyme!
+export setup, train!
 
 using ProgressLogging: @progress, @withprogress, @logprogress
 using Zygote: Zygote, Params
@@ -53,6 +53,12 @@ function setup(rule::Optimisers.AbstractRule, model)
   state
 end
 
+_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
+_make_zero_internal!(x) = x
+_make_zero!(model) = fmap(_make_zero_internal!, model)
+
+_applyloss(loss, model, d...) = loss(model, d...)
+
 """
     train!(loss, model, data, opt_state)
 
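The new `_make_zero!` helper walks a model tree with `fmap` and zeroes every array leaf in place, which is how the shadow (gradient) copy is reset before each Enzyme call. A rough sketch of the idea on a toy nested structure; the names `_zero_leaf!`, `zero_grads!`, and the sample `shadow` are illustrative only:

```julia
using Functors: fmap  # fmap comes from Functors (also available via Flux)

_zero_leaf!(x::AbstractArray) = fill!(x, 0)
_zero_leaf!(x) = x  # leave non-array leaves (functions, numbers) untouched

zero_grads!(shadow) = fmap(_zero_leaf!, shadow)

# Toy "shadow" holding stale gradients:
shadow = (W = [1.0 2.0; 3.0 4.0], b = [5.0, 6.0], σ = tanh)
zero_grads!(shadow)
shadow.W  # now all zeros; σ is untouched
```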
@@ -61,6 +67,9 @@ according to a particular optimisation rule encoded in `opt_state`.
 Iterates through `data` once, evaluating for each `d in data` either
 `loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`.
 
+If `model` is an Enzyme.Duplicated, gradients will be computed with Enzyme,
+otherwise they will be computed with Zygote.
+
 For example, with these definitions...
 ```
 data = [(x1, y1), (x2, y2), (x3, y3)]
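Per the new docstring note, a single `train!` entry point now picks the AD backend from the model's type. A hedged usage sketch: the model, data, loss, and the `Enzyme.make_zero` shadow construction below are assumptions for illustration, not taken from this commit:

```julia
using Flux
import Enzyme

model = Chain(Dense(2 => 3, relu), Dense(3 => 1))
data  = [(randn(Float32, 2, 8), randn(Float32, 1, 8))]
loss(m, x, y) = Flux.mse(m(x), y)
opt_state = Flux.setup(Adam(1f-3), model)

# Plain model: gradients via Zygote (existing behaviour).
Flux.train!(loss, model, data, opt_state)

# Duplicated model: gradients via Enzyme, accumulated into the shadow copy.
dup = Enzyme.Duplicated(model, Enzyme.make_zero(model))
Flux.train!(loss, dup, data, opt_state)
```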
@@ -101,60 +110,30 @@ function train!(loss, model, data, opt; cb = nothing)
     For more control use a loop with `gradient` and `update!`.""")
   @withprogress for (i,d) in enumerate(data)
     d_splat = d isa Tuple ? d : (d,)
-    l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
-    if !isfinite(l)
-      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
-    end
-    opt, model = Optimisers.update!(opt, model, gs[1])
-    @logprogress Base.haslength(data) ? i/length(data) : nothing
-  end
-end
+
+    if model isa Enzyme.Duplicated
+      _make_zero!(model.dval)
+      _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...)
 
-_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
-_make_zero_internal!(x) = x
-_make_zero!(model) = fmap(_make_zero_internal!, model)
+      if !isfinite(l)
+        throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
+      end
+      opt, model2 = Optimisers.update!(opt, model.val, model.dval)
+      model = Enzyme.Duplicated(model2, model.dval)
+    else
+      l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
 
-_applyloss(loss, model, d...) = loss(model, d...)
+      if !isfinite(l)
+        throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
+      end
 
-"""
-    train_enzyme!(loss, model_and_shadow, data, opt_state)
-
-Like [`train!`](@ref), but gradient computed in place using [Enzyme](github.com/EnzymeAD/Enzyme.jl)
-"""
-function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state::T) where T<:Optimisers.AbstractRule
-  @withprogress for (i,d) in enumerate(data)
-    d_splat = d isa Tuple ? d : (d,)
-    _make_zero!(model_and_shadow.dval)
-    _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model_and_shadow, map(Enzyme.Const, d_splat)...)
-
-    if !isfinite(l)
-      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
-    end
-    opt_state, model = Optimisers.update!(opt_state, model_and_shadow.val, model_and_shadow.dval)
-    model_and_shadow = Enzyme.Duplicated(model, model_and_shadow.dval)
-    @logprogress Base.haslength(data) ? i/length(data) : nothing
-  end
-end
+      opt, model = Optimisers.update!(opt, model, gs[1])
 
-# Required per method ambiguity with
-#   train!(loss, model, data, opt::Flux.Optimise.AbstractOptimiser; cb)
-#   @ Flux ~/work/Flux.jl/Flux.jl/src/deprecations.jl:110
-function train!(loss, model_and_shadow::Enzyme.Duplicated, data, opt_state::Flux.Optimise.AbstractOptimiser)
-  @withprogress for (i,d) in enumerate(data)
-    d_splat = d isa Tuple ? d : (d,)
-    _make_zero!(model_and_shadow.dval)
-    _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model_and_shadow, map(Enzyme.Const, d_splat)...)
-
-    if !isfinite(l)
-      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
-    end
     end
-    opt_state, model = Optimisers.update!(opt_state, model_and_shadow.val, model_and_shadow.dval)
-    model_and_shadow = Enzyme.Duplicated(model, model_and_shadow.dval)
     @logprogress Base.haslength(data) ? i/length(data) : nothing
   end
 end
 
-
 # This method let you use Optimisers.Descent() without setup, when there is no state
 function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
   train!(loss, model, data, _rule_to_state(model, rule); cb)
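For reference, a hedged standalone sketch of the `Enzyme.autodiff(Enzyme.ReverseWithPrimal, ...)` pattern the new loop relies on; the toy `sqloss`, `w`, `x`, and `y` are illustrative, not from Flux:

```julia
import Enzyme

# `ReverseWithPrimal` returns (derivatives, primal), so the loss value comes back
# without a separate forward pass; gradients accumulate into the Duplicated shadow.
sqloss(w, x, y) = sum(abs2, w .* x .- y)

w, dw = [1.0, 2.0], [0.0, 0.0]
x, y  = [3.0, 4.0], [5.0, 6.0]

_, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, sqloss,
                       Enzyme.Active,
                       Enzyme.Duplicated(w, dw),
                       Enzyme.Const(x), Enzyme.Const(y))
# l is the loss at w; dw now holds ∂l/∂w, while x and y stay constant
```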
