Skip to content

Commit 11f3fca

Browse files
authored
Enzyme: bump version and mark models as working [test] (#2439)
* Enzyme: bump version and mark models as working [test] * Update Project.toml * Update Project.toml * Update enzyme.jl * Mark transpose as not supported
1 parent 26c9acf commit 11f3fca

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ Adapt = "3, 4"
4040
CUDA = "4, 5"
4141
ChainRulesCore = "1.12"
4242
Compat = "4.10.0"
43-
Enzyme = "0.11"
43+
Enzyme = "0.12.4"
4444
FiniteDifferences = "0.12"
4545
Functors = "0.4"
4646
MLUtils = "0.4"

test/ext_enzyme/enzyme.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ end
120120
(Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
121121
(SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
122122
(Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
123+
(GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
123124
]
124125

125126
for (model, x, name) in models_xs
@@ -164,7 +165,7 @@ end
164165
device = Flux.get_device()
165166

166167
models_xs = [
167-
(GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
168+
# Pending https://github.com/FluxML/NNlib.jl/issues/565
168169
(ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
169170
]
170171

0 commit comments

Comments
 (0)