
Commit 419b99a

Fix method ambiguity and skip some tests for Enzyme

1 parent: 7b8309d

3 files changed: +21 additions, -15 deletions


src/deprecations.jl

Lines changed: 4 additions & 1 deletion

```diff
@@ -107,7 +107,10 @@ train!(loss, ps::Params, data, opt::Optimisers.AbstractRule; cb=nothing) = error
     But better to use the new explicit style, in which `m` itself is the 2nd argument.
     """)
 
-train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) = train!(loss, model, data, _old_to_new(opt); cb)
+train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
+    train!(loss, model, data, _old_to_new(opt); cb)
+train!(loss, model::Enzyme.Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
+    train!(loss, model, data, _old_to_new(opt); cb)
 
 # Next, to use the new `setup` with the still-exported old-style `Adam` etc:
 import .Train: setup
```
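For context: Julia throws a `MethodError` when a call matches two methods and neither is more specific in every argument, which is what the extra `model::Enzyme.Duplicated` method above resolves. A minimal standalone sketch of that situation (toy types and a toy function `f`, not Flux code):

```julia
# Toy reproduction of the ambiguity (standalone types, not Flux's own):
abstract type AbstractOptimiser end
struct Adam <: AbstractOptimiser end
struct Duplicated end        # stand-in for Enzyme.Duplicated

# One method is more specific in `opt`, the other in `model`:
f(loss, model, data, opt::AbstractOptimiser) = "deprecated-optimiser path"
f(loss, model::Duplicated, data, opt) = "Enzyme path"

# Neither method wins in every argument, so this call is ambiguous:
try
    f(identity, Duplicated(), [], Adam())
catch err
    @assert err isa MethodError   # MethodError: f(...) is ambiguous
end

# The commit's fix: add a method that is the most specific in both positions.
f(loss, model::Duplicated, data, opt::AbstractOptimiser) = "Enzyme + deprecated optimiser"
@assert f(identity, Duplicated(), [], Adam()) == "Enzyme + deprecated optimiser"
```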

src/train.jl

Lines changed: 2 additions & 2 deletions (both hunks appear to be whitespace-only cleanups of blank lines)

```diff
@@ -110,7 +110,7 @@ function train!(loss, model, data, opt; cb = nothing)
       For more control use a loop with `gradient` and `update!`.""")
   @withprogress for (i,d) in enumerate(data)
     d_splat = d isa Tuple ? d : (d,)
-
+
     l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
 
     if !isfinite(l)
@@ -127,7 +127,7 @@ function train!(loss, model::Enzyme.Duplicated, data, opt; cb = nothing)
       For more control use a loop with `gradient` and `update!`.""")
   @withprogress for (i,d) in enumerate(data)
     d_splat = d isa Tuple ? d : (d,)
-
+
     _make_zero!(model.dval)
    _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...)
```
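The context lines show the two code paths `train!` dispatches between: `Zygote.withgradient` for plain models and `Enzyme.autodiff` for `Enzyme.Duplicated` models. A minimal sketch of calling the explicit-style API (the toy model and data here are illustrative, not from the commit):

```julia
using Flux

model = Dense(2 => 1)                                   # toy model
data  = [(randn(Float32, 2, 8), randn(Float32, 1, 8))]  # one (x, y) batch
opt_state = Flux.setup(Adam(0.1), model)

# Explicit style: the loss receives the model itself as its first argument.
Flux.train!(model, data, opt_state) do m, x, y
    Flux.mse(m(x), y)
end

# For the Enzyme path, the model would be wrapped first (hedged sketch):
#   dup = Enzyme.Duplicated(model, Enzyme.make_zero(model))
#   Flux.train!(loss, dup, data, opt_state)
```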

test/train.jl

Lines changed: 15 additions & 12 deletions

```diff
@@ -84,15 +84,15 @@ end
   y1 = m(x)  # before
 
   # Implicit gradient
-  gold = gradient(() -> m(x), Flux.params(m))
+  gold = Zygote.gradient(() -> m(x), Flux.params(m))
   @test gold isa Flux.Zygote.Grads
   @test_throws ErrorException Flux.update!(Flux.Adam(), m, gold)  # friendly
   Flux.update!(Flux.Adam(), Flux.params(m), gold)
   y2 = m(x)
   @test y2 < y1
 
   # Explicit gradient
-  gs = gradient(marg -> marg(x), m)
+  gs = Zygote.gradient(marg -> marg(x), m)
   @test gs isa Tuple
   @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs)  # friendly
   @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs[1])  # friendly
@@ -133,17 +133,20 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
   diff1 = model.weight .- init_weight
 
   # Take 2: the same, but with Flux.params. Was broken for a bit, no tests!
-  model.weight .= init_weight
-  model.bias .= 0
-  pen2(x::AbstractArray) = sum(abs2, x)/2
-  opt = Flux.setup(Adam(0.1), model)
-  trainfn!(model, data, opt) do m, x, y
-    err = Flux.mse(m(x), y)
-    l2 = sum(pen2, Flux.params(m))
-    err + 0.33 * l2
+  # skipping this test for Enzyme cause implicit params is unsupported
+  if name == "Zygote"
+    model.weight .= init_weight
+    model.bias .= 0
+    pen2(x::AbstractArray) = sum(abs2, x)/2
+    opt = Flux.setup(Adam(0.1), model)
+    trainfn!(model, data, opt) do m, x, y
+      err = Flux.mse(m(x), y)
+      l2 = sum(pen2, Flux.params(m))
+      err + 0.33 * l2
+    end
+    diff2 = model.weight .- init_weight
+    @test diff1 ≈ diff2
   end
-  diff2 = model.weight .- init_weight
-  @test diff1 ≈ diff2
 
   # Take 3: using WeightDecay instead. Need the /2 above, to match exactly.
   model.weight .= init_weight
```
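The "Take 2" block is skipped under Enzyme because it relies on Zygote's implicit-parameter interface (`Flux.params` plus a zero-argument loss closure), for which Enzyme has no analogue: Enzyme instead differentiates an `Enzyme.Duplicated(model, shadow)` pair, so there is no `Params`/`Grads` dictionary to sum a penalty over. A minimal sketch of the implicit style being skipped (toy model, illustrative only):

```julia
using Flux, Zygote

m = Dense(3 => 2)
x = randn(Float32, 3, 4)

# Implicit style: a zero-argument closure; the parameters are gathered into
# a Params set, and gradients come back as a Grads keyed by parameter array.
ps = Flux.params(m)
gs = Zygote.gradient(() -> sum(m(x)), ps)
gs[m.weight]   # gradient for one parameter array
```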
