Skip to content

Commit 7b8309d

Browse files
committed
Rearrange dispatch for enzyme train
1 parent b329ef1 commit 7b8309d

File tree

2 files changed

+24
-53
lines changed

2 files changed

+24
-53
lines changed

src/functor.jl

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -90,45 +90,6 @@ function params!(p::Params, x, seen = IdSet())
9090
end
9191
end
9292

93-
# Custom EnzymeRules forward pass for `params!`: run the primal collection,
# and mirror the same traversal over the shadow (dval) storage so that the
# shadow `Params` aligns one-to-one with the primal `Params`.
#
# Returns an `AugmentedReturn` carrying the primal result (only if the config
# asks for it), the shadow result (only if the config asks for it), and no
# tape — the reverse rule below needs nothing cached.
function Enzyme.EnzymeRules.augmented_primal(config, func::Enzyme.Const{typeof(params!)}, ::Type{RT},
                                             p::Enzyme.Annotation,
                                             x::Enzyme.Annotation,
                                             seen::Enzyme.Annotation) where {RT}
    primal_val = func.val(p.val, x.val, seen.val)
    primal = EnzymeRules.needs_primal(config) ? primal_val : nothing

    # Repeat the walk on shadow storage; a Const `seen` has no shadow, so a
    # fresh IdSet is used as the visited-set in that case.
    shadow_val = if EnzymeRules.width(config) == 1
        func.val(p.dval, x.dval, seen isa Const ? IdSet() : seen.dval)
    else
        # Batched (vector-mode) differentiation: one shadow per batch lane.
        ntuple(Val(EnzymeRules.width(config))) do i
            Base.@_inline_meta
            func.val(p.dval[i], x.dval[i], seen isa Const ? IdSet() : seen.dval[i])
        end
    end
    shadow = EnzymeRules.needs_shadow(config) ? shadow_val : nothing

    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end
123-
124-
# Custom EnzymeRules reverse pass for `params!`: the forward rule already
# populated the shadow `Params`, so there is nothing to propagate here —
# every argument gets a `nothing` adjoint.
function Enzyme.EnzymeRules.reverse(config, func::Enzyme.Const{typeof(params!)}, ::Type{RT}, cache,
                                    p::Enzyme.Annotation,
                                    x::Enzyme.Annotation,
                                    seen::Enzyme.Annotation) where {RT}
    return ntuple(_ -> nothing, 3)
end
131-
13293
"""
13394
params(model)
13495
params(layers...)

src/train.jl

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -111,25 +111,32 @@ function train!(loss, model, data, opt; cb = nothing)
111111
@withprogress for (i,d) in enumerate(data)
112112
d_splat = d isa Tuple ? d : (d,)
113113

114-
if model isa Enzyme.Duplicated
115-
_make_zero!(model.dval)
116-
_, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...)
114+
l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
117115

118-
if !isfinite(l)
119-
throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
120-
end
121-
opt, model2 = Optimisers.update!(opt, model.val, model.dval)
122-
model = Enzyme.Duplicated(model2, model.dval)
123-
else
124-
l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
116+
if !isfinite(l)
117+
throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
118+
end
125119

126-
if !isfinite(l)
127-
throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
128-
end
120+
opt, model = Optimisers.update!(opt, model, gs[1])
129121

130-
opt, model = Optimisers.update!(opt, model, gs[1])
122+
@logprogress Base.haslength(data) ? i/length(data) : nothing
123+
end
124+
end
125+
# Enzyme-specific training loop, dispatched on `Enzyme.Duplicated` models.
# Mirrors the Zygote-based `train!` method, but computes gradients with
# `Enzyme.autodiff` into the model's shadow (`dval`) storage.
function train!(loss, model::Enzyme.Duplicated, data, opt; cb = nothing)
    cb === nothing || error("""train! does not support callback functions.
                               For more control use a loop with `gradient` and `update!`.""")
    @withprogress for (i, item) in enumerate(data)
        batch = item isa Tuple ? item : (item,)

        # Zero the shadow before Enzyme accumulates this step's gradient into it.
        _make_zero!(model.dval)
        _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss,
                               Enzyme.Active, Enzyme.Const(loss), model,
                               map(Enzyme.Const, batch)...)

        isfinite(l) || throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))

        # Update the primal model in place, then rewrap so later iterations
        # keep differentiating the updated parameters with the same shadow.
        opt, model2 = Optimisers.update!(opt, model.val, model.dval)
        model = Enzyme.Duplicated(model2, model.dval)

        @logprogress Base.haslength(data) ? i/length(data) : nothing
    end
end
@@ -138,6 +145,9 @@ end
138145
# Convenience method: accept a bare optimiser rule, build the optimiser
# state tree for `model`, and delegate to the stateful `train!` method.
function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
    state = _rule_to_state(model, rule)
    return train!(loss, model, data, state; cb)
end
148+
# Same rule-to-state convenience as above, specialised for `Enzyme.Duplicated`
# models so dispatch reaches the Enzyme training loop rather than the Zygote one.
function train!(loss, model::Enzyme.Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing)
    state = _rule_to_state(model, rule)
    return train!(loss, model, data, state; cb)
end
141151

142152
function _rule_to_state(model, rule::Optimisers.AbstractRule)
143153
state = setup(rule, model)

0 commit comments

Comments
 (0)