Skip to content

Commit 9b9f950

Browse files
start testing Enzyme (#2392)
* start testing
* add tests for Enzyme
* update runtests
* add comparison with FiniteDifferences
* cl/enzyme
* tests passing
* cleanup
* add FiniteDifferences to test extras
* rename check_grad -> test_grad
1 parent d811e8b commit 9b9f950

File tree

3 files changed

+197
-1
lines changed

3 files changed

+197
-1
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ Adapt = "3, 4"
4040
CUDA = "4, 5"
4141
ChainRulesCore = "1.12"
4242
Compat = "4.10.0"
43+
Enzyme = "0.11"
44+
FiniteDifferences = "0.12"
4345
Functors = "0.4"
4446
MLUtils = "0.4"
4547
MacroTools = "0.5"
@@ -62,7 +64,9 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
6264
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
6365
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
6466
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
67+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
6568
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
69+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
6670
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
6771
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6872
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
@@ -71,4 +75,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7175
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
7276

7377
[targets]
74-
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU"]
78+
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays",
79+
"ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU",
80+
"Enzyme", "FiniteDifferences"]

test/ext_enzyme/enzyme.jl

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
using Test
2+
using Flux
3+
4+
using Enzyme
5+
using Functors
6+
using FiniteDifferences
7+
using CUDA
8+
9+
# Global Enzyme configuration for this test file.
Enzyme.API.typeWarning!(false)    # suppress a spurious type warning hit by Bilinear, see EnzymeAD/Enzyme.jl#1341
Enzyme.API.runtimeActivity!(true) # enable runtime activity tracking; useful when debugging Enzyme failures
# Enzyme.Compiler.bitcode_replacement!(false)
12+
13+
# Map a single leaf to its zero: numbers and arrays become zeros of the
# same type/shape, every other leaf passes through unchanged.
function _make_zero(x::Union{Number,AbstractArray})
    return zero(x)
end
_make_zero(x) = x

# Build a zeroed "shadow" of `model` with the same structure, for use as the
# derivative slot of Enzyme's `Duplicated` arguments.
make_zero(model) = fmap(_make_zero, model)
## make_differential(model) = fmapstructure(make_zero, model) # NOT SUPPORTED, See https://github.com/EnzymeAD/Enzyme.jl/issues/1329
17+
18+
"""
    gradient_fd(f, xs...)

Reference gradient of `f` at `xs` via central finite differences.

Each argument is moved to CPU; non-array arguments are flattened with
`Flux.destructure` (arrays pass through with an identity re-builder),
promoted to `Float64` for accuracy, differentiated with a 5-point
central scheme, and finally re-assembled into the original structure.
Returns a tuple of gradients, one per argument.
"""
function gradient_fd(f, xs...)
    xs = [cpu(x) for x in xs]
    # Pair each argument with a function that rebuilds it from a flat vector.
    flat = [x isa AbstractArray ? (x, identity) : Flux.destructure(x) for x in xs]
    params = [f64(first(pr)) for pr in flat]
    rebuilds = [last(pr) for pr in flat]
    fdm = FiniteDifferences.central_fdm(5, 1)
    gs = FiniteDifferences.grad(fdm, (ps...) -> f((re(p) for (p, re) in zip(ps, rebuilds))...), params...)
    return ((re(g) for (re, g) in zip(rebuilds, gs))...,)
end
27+
28+
"""
    gradient_ez(f, xs...)

Gradient of `f` at `xs` computed with Enzyme's reverse mode.

Numbers are wrapped as `Active`; everything else gets a zeroed shadow via
`Duplicated`. Returns a tuple of gradients, one per argument: the `Active`
return slot for numbers, the accumulated `dval` shadow otherwise.
"""
function gradient_ez(f, xs...)
    wrapped = Any[x isa Number ? Active(x) : Duplicated(x, make_zero(x)) for x in xs]
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, wrapped...)
    return ntuple(length(xs)) do i
        xs[i] isa Number ? ret[1][i] : wrapped[i].dval
    end
end
41+
42+
"""
    test_grad(g1, g2; broken=false)

Compare two gradient structures leaf by leaf with `fmap_with_path`.
Numeric-array leaves are checked with `@test x ≈ y` (loose tolerances,
since one side is typically a finite-difference estimate); leaves whose
key path contains `:state` are skipped because RNN/LSTM hidden state is
not part of the gradient. Set `broken=true` to mark the comparisons as
known-broken.
"""
function test_grad(g1, g2; broken=false)
    fmap_with_path(g1, g2) do kp, x, y
        :state ∈ kp && return # ignore RNN and LSTM state
        if x isa AbstractArray{<:Number}
            # @show kp
            @test x ≈ y rtol=1e-2 atol=1e-6 broken=broken
        end
        return x
    end
end
52+
53+
"""
    test_enzyme_grad(loss, model, x)

Check that Enzyme's gradient of `loss(model, x)` agrees with a
finite-difference reference. Also verifies the loss is deterministic
(required for the finite-difference comparison to be meaningful).
"""
function test_enzyme_grad(loss, model, x)
    Flux.trainmode!(model)
    l = loss(model, x)
    @test loss(model, x) == l # Check loss doesn't change with multiple runs

    fd_grads = cpu(gradient_fd(loss, model, x))
    flux_grads = cpu(Flux.gradient(loss, model, x))
    enzyme_grads = cpu(gradient_ez(loss, model, x))

    # test_grad(flux_grads, enzyme_grads)
    test_grad(fd_grads, enzyme_grads)
end
65+
66+
@testset "gradient_ez" begin
    @testset "number and arrays" begin
        f(x, y) = sum(x.^2) + y^3
        x = Float32[1, 2, 3]
        y = 3f0
        g = gradient_ez(f, x, y)
        @test g[1] isa Array{Float32}
        @test g[2] isa Float32
        # Analytic gradients: d/dx sum(x²) = 2x, d/dy y³ = 3y².
        @test g[1] ≈ 2x
        @test g[2] ≈ 3*y^2
    end

    @testset "struct" begin
        # Minimal hand-rolled dense layer to check Enzyme handles custom
        # Functors-registered structs, not just built-in Flux layers.
        struct SimpleDense{W, B, F}
            weight::W
            bias::B
            σ::F
        end
        SimpleDense(in::Integer, out::Integer; σ=identity) = SimpleDense(randn(Float32, out, in), zeros(Float32, out), σ)
        (m::SimpleDense)(x) = m.σ.(m.weight * x .+ m.bias)
        @functor SimpleDense

        model = SimpleDense(2, 4)
        x = randn(Float32, 2)
        loss(model, x) = sum(model(x))

        g = gradient_ez(loss, model, x)
        @test g[1] isa SimpleDense
        @test g[2] isa Array{Float32}
        @test g[1].weight isa Array{Float32}
        @test g[1].bias isa Array{Float32}
        # For identity activation, ∂loss/∂W = 1⋅xᵀ and ∂loss/∂b = 1.
        @test g[1].weight ≈ ones(Float32, 4, 1) .* x'
        @test g[1].bias ≈ ones(Float32, 4)
    end
end
101+
102+
@testset "Models" begin
    # Shared loss: reset any recurrent state, then reduce the forward pass.
    function loss(model, x)
        Flux.reset!(model)
        sum(model(x))
    end

    # (model, input, label) triples covering the main layer families.
    models_xs = [
        (Dense(2, 4), randn(Float32, 2), "Dense"),
        (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2), "Chain(Dense, Dense)"),
        (f64(Chain(Dense(2, 4), Dense(4, 2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"),
        (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"),
        (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),
        (Chain(Conv((3, 3), 2 => 3, relu), Conv((3, 3), 3 => 1, relu)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),
        (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
        (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),
        (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
        (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
        (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
        (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
        (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
        (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
    ]

    @testset "check grad $name" for (model, x, name) in models_xs
        println("testing $name")
        test_enzyme_grad(loss, model, x)
    end
end
132+
133+
@testset "Recurrence Tests" begin
    # Unroll the model for three steps before reducing, so the gradient
    # must flow through the recurrent state across iterations.
    function loss(model, x)
        Flux.reset!(model)
        for _ in 1:3
            x = model(x)
        end
        return sum(x)
    end

    models_xs = [
        (RNN(3 => 3), randn(Float32, 3, 2), "RNN"),
        (LSTM(3 => 3), randn(Float32, 3, 2), "LSTM"),
        # TESTS BELOW ARE BROKEN FOR ZYGOTE BUT CORRECT FOR ENZYME!
        (Chain(RNN(3 => 5), RNN(5 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
        (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
    ]

    @testset "check grad $name" for (model, x, name) in models_xs
        println("testing $name")
        test_enzyme_grad(loss, model, x)
    end
end
157+
158+
@testset "Broken Models" begin
    function loss(model, x)
        Flux.reset!(model)
        sum(model(x))
    end

    # Models currently known to fail under Enzyme. We assert that they DO
    # fail, so this testset starts erroring the moment one of them is fixed
    # upstream and can be promoted to the regular "Models" set.
    models_xs = [
        (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
        (ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
    ]

    for (model, x, name) in models_xs
        @testset "check grad $name" begin
            println("testing $name")
            broken = false
            try
                test_enzyme_grad(loss, model, x)
            catch e
                println(e)
                broken = true
            end
            @test broken
        end
    end
end
185+

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,9 @@ Random.seed!(0)
116116
@info "Skipping Metal tests, set FLUX_TEST_METAL=true to run them."
117117
end
118118

119+
# Enzyme extension tests (see test/ext_enzyme/enzyme.jl).
@testset "Enzyme" begin
    import Enzyme
    include("ext_enzyme/enzyme.jl")
end
123+
119124
end

0 commit comments

Comments
 (0)