Skip to content

Commit 60bc744

Browse files
authored
Enable remaining enzyme test (#2442)
1 parent 11f3fca commit 60bc744

File tree

2 files changed

+2
-34
lines changed

2 files changed

+2
-34
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Functors = "0.4"
4646
MLUtils = "0.4"
4747
MacroTools = "0.5"
4848
Metal = "0.5, 1"
49-
NNlib = "0.9.14"
49+
NNlib = "0.9.15"
5050
OneHotArrays = "0.2.4"
5151
Optimisers = "0.3.3"
5252
Preferences = "1"

test/ext_enzyme/enzyme.jl

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,6 @@ using Functors
66
using FiniteDifferences
77
using CUDA
88

9-
Enzyme.API.typeWarning!(false) # suppresses a warning with Bilinear https://github.com/EnzymeAD/Enzyme.jl/issues/1341
10-
Enzyme.API.runtimeActivity!(true) # for Enzyme debugging
11-
# Enzyme.Compiler.bitcode_replacement!(false)
12-
139
_make_zero(x::Union{Number,AbstractArray}) = zero(x)
1410
_make_zero(x) = x
1511
make_zero(model) = fmap(_make_zero, model)
@@ -121,6 +117,7 @@ end
121117
(SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
122118
(Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
123119
(GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
120+
(ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
124121
]
125122

126123
for (model, x, name) in models_xs
@@ -155,32 +152,3 @@ end
155152
end
156153
end
157154
end
158-
159-
@testset "Broken Models" begin
160-
function loss(model, x)
161-
Flux.reset!(model)
162-
sum(model(x))
163-
end
164-
165-
device = Flux.get_device()
166-
167-
models_xs = [
168-
# Pending https://github.com/FluxML/NNlib.jl/issues/565
169-
(ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
170-
]
171-
172-
for (model, x, name) in models_xs
173-
@testset "check grad $name" begin
174-
println("testing $name")
175-
broken = false
176-
try
177-
test_enzyme_grad(loss, model, x)
178-
catch e
179-
println(e)
180-
broken = true
181-
end
182-
@test broken
183-
end
184-
end
185-
end
186-

0 commit comments

Comments
 (0)