Skip to content

Commit 4b7c980

Browse files
committed
add more conv groups tests
1 parent 6386c5d commit 4b7c980

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

test/layers/conv.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,20 @@ end
6767
@test Flux.Losses.mse(bias(ip), op) 4.f0
6868

6969
@testset "Grouped Conv" begin
70+
ip = rand(Float32, 28, 100, 2)
71+
c = Conv((3,), 100 => 25, groups = 5)
72+
@test size(c.weight) == (3, 20, 25)
73+
@test size(c(ip)) == (26, 25, 2)
74+
7075
ip = rand(Float32, 28, 28, 100, 2)
7176
c = Conv((3,3), 100 => 25, groups = 5)
7277
@test size(c.weight) == (3, 3, 20, 25)
7378
@test size(c(ip)) == (26, 26, 25, 2)
79+
80+
ip = rand(Float32, 10, 11, 12, 100, 2)
81+
c = Conv((3,4,5), 100 => 25, groups = 5)
82+
@test size(c.weight) == (3,4,5, 20, 25)
83+
@test size(c(ip)) == (8,8,8, 25, 2)
7484
end
7585
end
7686

@@ -129,6 +139,21 @@ end
129139
@test size(m1(x)) == size(m2(x))
130140
@test gradient(()->sum(m2(x)), params(m2)) isa Flux.Zygote.Grads
131141

142+
x = randn(Float32, 10, 2,1)
143+
m = ConvTranspose((3,), 2=>4, pad=SamePad(), groups=2)
144+
@test size(m(x)) === (10,4,1)
145+
@test length(m.weight) == (3)*(2*4) / 2
146+
147+
x = randn(Float32, 10, 11, 4,2)
148+
m = ConvTranspose((3,5), 4=>4, pad=SamePad(), groups=4)
149+
@test size(m(x)) === (10,11, 4,2)
150+
@test length(m.weight) == (3*5)*(4*4)/4
151+
152+
x = randn(Float32, 10, 11, 12, 3,2)
153+
m = ConvTranspose((3,5,3), 3=>6, pad=SamePad(), groups=3)
154+
@test size(m(x)) === (10,11, 12, 6,2)
155+
@test length(m.weight) == (3*5*3) * (3*6) / 3
156+
132157
@test occursin("groups=2", sprint(show, ConvTranspose((3,3), 2=>4, groups=2)))
133158
@test occursin("2 => 4" , sprint(show, ConvTranspose((3,3), 2=>4, groups=2)))
134159
end
@@ -138,7 +163,7 @@ end
138163
w = rand(Float32, 2,2,1,1)
139164
y = CrossCor(w, [0.0])
140165

141-
@test sum(w .* x[1:2, 1:2, :, :]) y(x)[1, 1, 1, 1] rtol=1e-7
166+
@test sum(w .* x[1:2, 1:2, :, :]) y(x)[1, 1, 1, 1] rtol=2e-7
142167

143168
r = zeros(Float32, 28, 28, 1, 5)
144169
m = Chain(

0 commit comments

Comments
 (0)