|
67 | 67 | @test Flux.Losses.mse(bias(ip), op) ≈ 4.f0
|
68 | 68 |
|
69 | 69 | @testset "Grouped Conv" begin
|
| 70 | + ip = rand(Float32, 28, 100, 2) |
| 71 | + c = Conv((3,), 100 => 25, groups = 5) |
| 72 | + @test size(c.weight) == (3, 20, 25) |
| 73 | + @test size(c(ip)) == (26, 25, 2) |
| 74 | + |
70 | 75 | ip = rand(Float32, 28, 28, 100, 2)
|
71 | 76 | c = Conv((3,3), 100 => 25, groups = 5)
|
72 | 77 | @test size(c.weight) == (3, 3, 20, 25)
|
73 | 78 | @test size(c(ip)) == (26, 26, 25, 2)
|
| 79 | + |
| 80 | + ip = rand(Float32, 10, 11, 12, 100, 2) |
| 81 | + c = Conv((3,4,5), 100 => 25, groups = 5) |
| 82 | + @test size(c.weight) == (3,4,5, 20, 25) |
| 83 | + @test size(c(ip)) == (8,8,8, 25, 2) |
74 | 84 | end
|
75 | 85 | end
|
76 | 86 |
|
|
129 | 139 | @test size(m1(x)) == size(m2(x))
|
130 | 140 | @test gradient(()->sum(m2(x)), params(m2)) isa Flux.Zygote.Grads
|
131 | 141 |
|
| 142 | + x = randn(Float32, 10, 2,1) |
| 143 | + m = ConvTranspose((3,), 2=>4, pad=SamePad(), groups=2) |
| 144 | + @test size(m(x)) === (10,4,1) |
| 145 | + @test length(m.weight) == (3)*(2*4) / 2 |
| 146 | + |
| 147 | + x = randn(Float32, 10, 11, 4,2) |
| 148 | + m = ConvTranspose((3,5), 4=>4, pad=SamePad(), groups=4) |
| 149 | + @test size(m(x)) === (10,11, 4,2) |
| 150 | + @test length(m.weight) == (3*5)*(4*4)/4 |
| 151 | + |
| 152 | + x = randn(Float32, 10, 11, 12, 3,2) |
| 153 | + m = ConvTranspose((3,5,3), 3=>6, pad=SamePad(), groups=3) |
| 154 | + @test size(m(x)) === (10,11, 12, 6,2) |
| 155 | + @test length(m.weight) == (3*5*3) * (3*6) / 3 |
| 156 | + |
132 | 157 | @test occursin("groups=2", sprint(show, ConvTranspose((3,3), 2=>4, groups=2)))
|
133 | 158 | @test occursin("2 => 4" , sprint(show, ConvTranspose((3,3), 2=>4, groups=2)))
|
134 | 159 | end
|
|
138 | 163 | w = rand(Float32, 2,2,1,1)
|
139 | 164 | y = CrossCor(w, [0.0])
|
140 | 165 |
|
141 |
| - @test sum(w .* x[1:2, 1:2, :, :]) ≈ y(x)[1, 1, 1, 1] rtol=1e-7 |
| 166 | + @test sum(w .* x[1:2, 1:2, :, :]) ≈ y(x)[1, 1, 1, 1] rtol=2e-7 |
142 | 167 |
|
143 | 168 | r = zeros(Float32, 28, 28, 1, 5)
|
144 | 169 | m = Chain(
|
|
0 commit comments