@@ -6,10 +6,6 @@ using Functors
6
6
using FiniteDifferences
7
7
using CUDA
8
8
9
- Enzyme. API. typeWarning! (false ) # suppresses a warning with Bilinear https://github.com/EnzymeAD/Enzyme.jl/issues/1341
10
- Enzyme. API. runtimeActivity! (true ) # for Enzyme debugging
11
- # Enzyme.Compiler.bitcode_replacement!(false)
12
-
13
9
_make_zero (x:: Union{Number,AbstractArray} ) = zero (x)
14
10
_make_zero (x) = x
15
11
make_zero (model) = fmap (_make_zero, model)
121
117
(SkipConnection (Dense (2 => 2 ), vcat), randn (Float32, 2 , 3 ), " SkipConnection" ),
122
118
(Flux. Bilinear ((2 , 2 ) => 3 ), randn (Float32, 2 , 1 ), " Bilinear" ),
123
119
(GRU (3 => 5 ), randn (Float32, 3 , 10 ), " GRU" ),
120
+ (ConvTranspose ((3 , 3 ), 3 => 2 , stride= 2 ), rand (Float32, 5 , 5 , 3 , 1 ), " ConvTranspose" ),
124
121
]
125
122
126
123
for (model, x, name) in models_xs
155
152
end
156
153
end
157
154
end
158
-
159
- @testset " Broken Models" begin
160
- function loss (model, x)
161
- Flux. reset! (model)
162
- sum (model (x))
163
- end
164
-
165
- device = Flux. get_device ()
166
-
167
- models_xs = [
168
- # Pending https://github.com/FluxML/NNlib.jl/issues/565
169
- (ConvTranspose ((3 , 3 ), 3 => 2 , stride= 2 ), rand (Float32, 5 , 5 , 3 , 1 ), " ConvTranspose" ),
170
- ]
171
-
172
- for (model, x, name) in models_xs
173
- @testset " check grad $name " begin
174
- println (" testing $name " )
175
- broken = false
176
- try
177
- test_enzyme_grad (loss, model, x)
178
- catch e
179
- println (e)
180
- broken = true
181
- end
182
- @test broken
183
- end
184
- end
185
- end
186
-
0 commit comments