@@ -48,7 +48,7 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; te
48
48
xg_cpu = gradient (x -> sum (l_cpu (x)), x_cpu)[1 ]
49
49
xg_gpu = gradient (x -> sum (l_gpu (x)), x_gpu)[1 ]
50
50
51
- # test
51
+ # test
52
52
if test_cpu
53
53
@test y_gpu ≈ y_cpu rtol= 1f-3 atol= 1f-3
54
54
if isnothing (xg_cpu)
@@ -80,6 +80,7 @@ ConvTransposeNoBias(args...) = ConvTranspose(args...; bias = false)
80
80
CrossCorNoBias (args... ) = CrossCor (args... ; bias = false )
81
81
DepthwiseConvNoBias (args... ) = DepthwiseConv (args... ; bias = false )
82
82
GroupedConv (args... ) = Conv (args... , groups = 5 )
83
+ GroupedConvTranspose (args... ) = ConvTranspose (args... , groups = 5 )
83
84
84
85
for act in ACTIVATIONS
85
86
r = rand (Float32, 28 , 28 , 1 , 1 )
@@ -89,16 +90,16 @@ for act in ACTIVATIONS
89
90
DepthwiseConv, DepthwiseConvNoBias]
90
91
gpu_gradtest (" Convolution with $act " , conv_layers, r, (2 ,2 ), 1 => 3 , act, test_cpu = false )
91
92
92
- groupedconv = [GroupedConv]
93
+ groupedconv = [GroupedConv, GroupedConvTranspose ]
93
94
gpu_gradtest (" GroupedConvolution with $act " , groupedconv, rand (Float32, 28 , 28 , 100 , 2 ), (3 ,3 ), 100 => 25 , act, test_cpu = true )
94
95
95
96
batch_norm = [BatchNorm]
96
97
gpu_gradtest (" BatchNorm 1 with $act " , batch_norm, rand (Float32, 28 ,28 ,3 ,4 ), 3 , act, test_cpu = false ) # TODO fix errors
97
98
gpu_gradtest (" BatchNorm 2 with $act " , batch_norm, rand (Float32, 5 ,4 ), 5 , act, test_cpu = false )
98
-
99
+
99
100
instancenorm = [InstanceNorm]
100
101
gpu_gradtest (" InstanceNorm with $act " , instancenorm, r, 1 , act, test_cpu = false )
101
-
102
+
102
103
groupnorm = [GroupNorm]
103
104
gpu_gradtest (" GroupNorm with $act " , groupnorm, rand (Float32, 28 ,28 ,3 ,1 ), 3 , 1 , act, test_cpu = false )
104
105
end
151
152
else
152
153
@test sum (l (ip)) ≈ 0.f0
153
154
gs = gradient (() -> sum (l (ip)), Flux. params (l))
154
- @test l. bias ∉ gs. params
155
+ @test l. bias ∉ gs. params
155
156
end
156
157
end
157
158
0 commit comments