Skip to content

Commit 0be4401

Browse files
authored
Merge pull request #2446 from wsmoses/master
Add Enzyme train function
2 parents f3021ba + 2796aac commit 0be4401

File tree

7 files changed

+95
-28
lines changed

7 files changed

+95
-28
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.
44

5+
## v0.14.17
6+
* Add [support for Enzyme](https://github.com/FluxML/Flux.jl/pull/2446) with `Flux.train!`.
7+
58
## v0.14.13
69
* New macro `Flux.@layer` which should be used in place of `@functor`.
710
This also adds `show` methods for pretty printing.

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
name = "Flux"
22
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
3-
version = "0.14.16"
3+
version = "0.14.17"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
88
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
9+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
910
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"

src/deprecations.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,10 @@ train!(loss, ps::Params, data, opt::Optimisers.AbstractRule; cb=nothing) = error
107107
But better to use the new explicit style, in which `m` itself is the 2nd argument.
108108
""")
109109

110-
train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) = train!(loss, model, data, _old_to_new(opt); cb)
110+
train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
111+
train!(loss, model, data, _old_to_new(opt); cb)
112+
train!(loss, model::Enzyme.Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
113+
train!(loss, model, data, _old_to_new(opt); cb)
111114

112115
# Next, to use the new `setup` with the still-exported old-style `Adam` etc:
113116
import .Train: setup

src/functor.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using LinearAlgebra: Cholesky
33
using Zygote: IdSet
44
import Functors: Functors, @functor, functor, fmap, isleaf
55
using SparseArrays: AbstractSparseArray
6+
using Enzyme
67

78
"""
89
testmode!(model, [mode]) -> model

src/losses/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import Enzyme
2+
13
"""
24
xlogx(x)
35
@@ -36,3 +38,4 @@ end
3638
_check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1
3739

3840
ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)
41+
Enzyme.EnzymeRules.inactive(::typeof(_check_sizes), args...) = true

src/train.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Optimisers: Optimisers
55
using Functors: fmap, fmapstructure
66
using ..Flux: Flux # used only in docstring
77
import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions
8+
import Enzyme
89

910
export setup, train!
1011

@@ -52,6 +53,12 @@ function setup(rule::Optimisers.AbstractRule, model)
5253
state
5354
end
5455

56+
_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
57+
_make_zero_internal!(x) = x
58+
_make_zero!(model) = fmap(_make_zero_internal!, model)
59+
60+
_applyloss(loss, model, d...) = loss(model, d...)
61+
5562
"""
5663
train!(loss, model, data, opt_state)
5764
@@ -60,6 +67,9 @@ according to a particular optimisation rule encoded in `opt_state`.
6067
Iterates through `data` once, evaluating for each `d in data` either
6168
`loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`.
6269
70+
If `model` is an Enzyme.Duplicated, gradients will be computed with Enzyme,
71+
otherwise they will be computed with Zygote.
72+
6373
For example, with these definitions...
6474
```
6575
data = [(x1, y1), (x2, y2), (x3, y3)]
@@ -100,11 +110,33 @@ function train!(loss, model, data, opt; cb = nothing)
100110
For more control use a loop with `gradient` and `update!`.""")
101111
@withprogress for (i,d) in enumerate(data)
102112
d_splat = d isa Tuple ? d : (d,)
113+
103114
l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model)
115+
104116
if !isfinite(l)
105117
throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
106118
end
119+
107120
opt, model = Optimisers.update!(opt, model, gs[1])
121+
122+
@logprogress Base.haslength(data) ? i/length(data) : nothing
123+
end
124+
end
125+
function train!(loss, model::Enzyme.Duplicated, data, opt; cb = nothing)
126+
isnothing(cb) || error("""train! does not support callback functions.
127+
For more control use a loop with `gradient` and `update!`.""")
128+
@withprogress for (i,d) in enumerate(data)
129+
d_splat = d isa Tuple ? d : (d,)
130+
131+
_make_zero!(model.dval)
132+
_, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...)
133+
134+
if !isfinite(l)
135+
throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
136+
end
137+
opt, model2 = Optimisers.update!(opt, model.val, model.dval)
138+
model = Enzyme.Duplicated(model2, model.dval)
139+
108140
@logprogress Base.haslength(data) ? i/length(data) : nothing
109141
end
110142
end
@@ -113,6 +145,9 @@ end
113145
function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
114146
train!(loss, model, data, _rule_to_state(model, rule); cb)
115147
end
148+
function train!(loss, model::Enzyme.Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing)
149+
train!(loss, model, data, _rule_to_state(model, rule); cb)
150+
end
116151

117152
function _rule_to_state(model, rule::Optimisers.AbstractRule)
118153
state = setup(rule, model)

test/train.jl

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,14 @@ import Optimisers
44

55
using Test
66
using Random
7+
using Enzyme
78

8-
@testset "Explicit Flux.train! with Zygote" begin
9+
function train_enzyme!(fn, model, args...; kwargs...)
10+
Flux.train!(fn, Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...)
11+
end
12+
13+
for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
14+
@testset "Explicit Flux.train! with $name" begin
915
Random.seed!(84)
1016
w = randn(10, 10)
1117
w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
@@ -18,31 +24,40 @@ using Random
1824
@test loss(model, rand(10, 10)) > 1
1925

2026
opt = Flux.setup(rule, model)
21-
Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
27+
trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
2228
@test loss(model, rand(10, 10)) < 0.01
2329
end
2430

2531
# Test direct use of Optimisers.jl rule, only really OK for `Descent`:
32+
# Enzyme doesn't work with un-initialized atm, presumably due to trainmode?
33+
if name != "Enzyme"
2634
@testset "without setup, $opt" for opt in [Descent(0.1), Optimisers.Descent(0.1), Optimisers.Adam()]
2735
loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
2836
model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
2937
@test loss(model, rand(10, 10)) > 1
30-
Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
38+
trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
3139
@test loss(model, rand(10, 10)) < 0.01
3240
end
41+
end
42+
end
3343
end
3444

35-
@testset "Explicit Flux.train! features" begin
45+
for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
46+
@testset "Explicit Flux.train! features with $name" begin
3647
@testset "Stop on NaN" begin
3748
m1 = Dense(1 => 1)
3849
m1.weight .= 0
39-
CNT = 0
40-
@test_throws DomainError Flux.train!(m1, tuple.(1:100), Descent(0.1)) do m, i
41-
CNT += 1
50+
CNT = Ref(0)
51+
@test_throws DomainError trainfn!(m1, tuple.(1:100), Descent(0.1)) do m, i
52+
CNT[] += 1
4253
(i == 51 ? NaN32 : 1f0) * sum(m([1.0]))
4354
end
44-
@test CNT == 51 # stopped early
45-
@test m1.weight[1] -5 # did not corrupt weights
55+
@test CNT[] == 51 # stopped early
56+
if name != "Enzyme"
57+
@test m1.weight[1] -5 # did not corrupt weights
58+
else
59+
@test m1.weight[1] 0.0 # did not corrupt weights
60+
end
4661
end
4762

4863
@testset "non-tuple data" begin
@@ -51,32 +66,33 @@ end
5166
loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
5267
model = (weight=copy(w2), bias=zeros(10))
5368
opt = Flux.setup(AdamW(), model)
54-
Flux.train!(loss, model, (rand(10) for _ in 1: 10^5), opt)
69+
trainfn!(loss, model, (rand(10) for _ in 1: 10^5), opt)
5570
@test loss(model, rand(10, 10)) < 0.01
5671
end
5772

5873
@testset "callbacks give helpful error" begin
5974
m1 = Dense(1 => 1)
6075
cb = () -> println("this should not be printed")
61-
@test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb)
76+
@test_throws ErrorException trainfn!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb)
6277
end
6378
end
79+
end
6480

6581
@testset "Explicit Flux.update! features" begin
6682
m = Chain(Dense(2=>3, tanh), Dense(3=>1), only)
6783
x = rand(2)
6884
y1 = m(x) # before
6985

7086
# Implicit gradient
71-
gold = gradient(() -> m(x), Flux.params(m))
87+
gold = Zygote.gradient(() -> m(x), Flux.params(m))
7288
@test gold isa Flux.Zygote.Grads
7389
@test_throws ErrorException Flux.update!(Flux.Adam(), m, gold) # friendly
7490
Flux.update!(Flux.Adam(), Flux.params(m), gold)
7591
y2 = m(x)
7692
@test y2 < y1
7793

7894
# Explicit gradient
79-
gs = gradient(marg -> marg(x), m)
95+
gs = Zygote.gradient(marg -> marg(x), m)
8096
@test gs isa Tuple
8197
@test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs) # friendly
8298
@test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs[1]) # friendly
@@ -98,7 +114,8 @@ end
98114
@test y5 < y4
99115
end
100116

101-
@testset "L2 regularisation" begin
117+
for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
118+
@testset "L2 regularisation with $name" begin
102119
# New docs claim an exact equivalent. It's a bit long to put the example in there,
103120
# but perhaps the tests should contain it.
104121

@@ -108,36 +125,40 @@ end
108125

109126
# Take 1: explicitly add a penalty in the loss function
110127
opt = Flux.setup(Adam(0.1), model)
111-
Flux.train!(model, data, opt) do m, x, y
128+
trainfn!(model, data, opt) do m, x, y
112129
err = Flux.mse(m(x), y)
113130
l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
114131
err + 0.33 * l2
115132
end
116133
diff1 = model.weight .- init_weight
117134

118135
# Take 2: the same, but with Flux.params. Was broken for a bit, no tests!
119-
model.weight .= init_weight
120-
model.bias .= 0
121-
pen2(x::AbstractArray) = sum(abs2, x)/2
122-
opt = Flux.setup(Adam(0.1), model)
123-
Flux.train!(model, data, opt) do m, x, y
124-
err = Flux.mse(m(x), y)
125-
l2 = sum(pen2, Flux.params(m))
126-
err + 0.33 * l2
136+
# skipping this test for Enzyme cause implicit params is unsupported
137+
if name == "Zygote"
138+
model.weight .= init_weight
139+
model.bias .= 0
140+
pen2(x::AbstractArray) = sum(abs2, x)/2
141+
opt = Flux.setup(Adam(0.1), model)
142+
trainfn!(model, data, opt) do m, x, y
143+
err = Flux.mse(m(x), y)
144+
l2 = sum(pen2, Flux.params(m))
145+
err + 0.33 * l2
146+
end
147+
diff2 = model.weight .- init_weight
148+
@test diff1 diff2
127149
end
128-
diff2 = model.weight .- init_weight
129-
@test diff1 diff2
130150

131151
# Take 3: using WeightDecay instead. Need the /2 above, to match exactly.
132152
model.weight .= init_weight
133153
model.bias .= 0
134154
decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model);
135-
Flux.train!(model, data, decay_opt) do m, x, y
155+
trainfn!(model, data, decay_opt) do m, x, y
136156
Flux.mse(m(x), y)
137157
end
138158
diff3 = model.weight .- init_weight
139159
@test diff1 diff3
140160
end
161+
end
141162

142163
@testset "Flux.setup bugs" begin
143164
# https://github.com/FluxML/Flux.jl/issues/2144

0 commit comments

Comments
 (0)