Skip to content

Commit da35dcc

Browse files
committed
add tests for _channels_in, _channels_out
1 parent d41d6ea commit da35dcc

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

test/layers/conv.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,43 @@ end
8484
end
8585
end
8686

87+
@testset "_channels_in, _channels_out" begin
88+
_channels_in = Flux._channels_in
89+
_channels_out = Flux._channels_out
90+
@test _channels_in(Conv((3,) , 2=>4)) == 2
91+
@test _channels_in(Conv((5,6,) , 2=>4)) == 2
92+
@test _channels_in(Conv((1,2,3), 2=>4)) == 2
93+
@test _channels_out(Conv((3,) , 2=>4)) == 4
94+
@test _channels_out(Conv((5,6,) , 2=>4)) == 4
95+
@test _channels_out(Conv((1,2,3), 2=>4)) == 4
96+
97+
@test _channels_in( ConvTranspose((3,) , 1=>4)) == 1
98+
@test _channels_in( ConvTranspose((5,6,) , 2=>4)) == 2
99+
@test _channels_in( ConvTranspose((1,2,3), 3=>4)) == 3
100+
@test _channels_out(ConvTranspose((3,) , 2=>1)) == 1
101+
@test _channels_out(ConvTranspose((5,6,) , 2=>2)) == 2
102+
@test _channels_out(ConvTranspose((1,2,3), 2=>3)) == 3
103+
104+
@test _channels_in( ConvTranspose((6,) , 8=>4, groups=4)) == 8
105+
@test _channels_in( ConvTranspose((5,6,) , 2=>4, groups=2)) == 2
106+
@test _channels_in( ConvTranspose((1,2,3), 3=>6, groups=3)) == 3
107+
108+
@test _channels_out(ConvTranspose((1,) , 10=>15, groups=5)) == 15
109+
@test _channels_out(ConvTranspose((3,2) , 10=>15, groups=5)) == 15
110+
@test _channels_out(ConvTranspose((5,6,) , 2=>2, groups=2)) == 2
111+
112+
for Layer in [Conv, ConvTranspose]
113+
for _ in 1:10
114+
groups = rand(1:10)
115+
kernel_size = Tuple(rand(1:5) for _ in rand(1:3))
116+
cin = rand(1:5) * groups
117+
cout = rand(1:5) * groups
118+
@test _channels_in(Layer(kernel_size, cin=>cout; groups)) == cin
119+
@test _channels_out(Layer(kernel_size, cin=>cout; groups)) == cout
120+
end
121+
end
122+
end
123+
87124
@testset "asymmetric padding" begin
88125
r = ones(Float32, 28, 28, 1, 1)
89126
m = Conv((3, 3), 1=>1, relu; pad=(0,1,1,2))

0 commit comments

Comments
 (0)