
Commit f66be89

bors[bot], jw3126, and DhairyaLGandhi authored

Merge #1744

1744: allow groups in ConvTranspose r=DhairyaLGandhi a=jw3126

fix #1743

Co-authored-by: Jan Weidner <jw3126@gmail.com>
Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
2 parents f8ebead + ac3cdfa commit f66be89
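
For context, this change gives ConvTranspose the same groups keyword that Conv already accepts. A minimal usage sketch, with shapes mirroring the m2 case in the new tests (assumes Flux at or above the version bumped in this commit):

using Flux

x = randn(Float32, 10, 10, 2, 3)                        # width × height × channels × batch
m = ConvTranspose((3, 3), 2 => 4, groups = 2, pad = SamePad())

size(m.weight)   # (3, 3, 2, 2): each of the 2 groups maps 1 input channel to 2 outputs
size(m(x))       # (10, 10, 4, 3): SamePad keeps the spatial size; 4 output channels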

File tree

4 files changed, +129 -31 lines

Project.toml
src/layers/conv.jl
test/cuda/layers.jl
test/layers/conv.jl

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "Flux"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.12.7"
+version = "0.12.8"

 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/layers/conv.jl

Lines changed: 41 additions & 23 deletions
@@ -136,7 +136,7 @@ end

 function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
               init = glorot_uniform, stride = 1, pad = 0, dilation = 1, groups = 1,
-              weight = convfilter(k, (ch[1] ÷ groups => ch[2]); init), bias = true) where N
+              weight = convfilter(k, ch; init, groups), bias = true) where N

   Conv(weight, bias, σ; stride, pad, dilation, groups)
 end
@@ -152,8 +152,11 @@ distribution.

 See also: [`depthwiseconvfilter`](@ref)
 """
-convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
-           init = glorot_uniform) where N = init(filter..., ch...)
+function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
+                    init = glorot_uniform, groups = 1) where N
+  cin, cout = ch
+  init(filter..., cin÷groups, cout)
+end

 @functor Conv
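
As a point of reference, the reworked convfilter divides only the input channels by groups, so a grouped Conv's kernel keeps the full output-channel dimension. A quick sketch using the shapes asserted in the grouped Conv test further below:

using Flux

w = Flux.convfilter((3, 3), 100 => 25; groups = 5)
size(w)   # (3, 3, 20, 25): 100 ÷ 5 = 20 input channels per group, 25 output channels in total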

@@ -163,9 +166,12 @@ function (c::Conv)(x::AbstractArray)
   σ.(conv(x, c.weight, cdims) .+ b)
 end

+_channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
+_channels_out(l::Conv) = size(l.weight, ndims(l.weight))
+
 function Base.show(io::IO, l::Conv)
   print(io, "Conv(", size(l.weight)[1:ndims(l.weight)-2])
-  print(io, ", ", size(l.weight, ndims(l.weight)-1), " => ", size(l.weight, ndims(l.weight)))
+  print(io, ", ", _channels_in(l), " => ", _channels_out(l))
   _print_conv_opt(io, l)
   print(io, ")")
 end
@@ -175,7 +181,10 @@ function _print_conv_opt(io::IO, l)
   all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad))
   all(==(1), l.stride) || print(io, ", stride=", _maybetuple_string(l.stride))
   all(==(1), l.dilation) || print(io, ", dilation=", _maybetuple_string(l.dilation))
-  l.bias == Zeros() && print(io, ", bias=false")
+  if hasproperty(l, :groups)
+    (l.groups == 1) || print(io, ", groups=", l.groups)
+  end
+  (l.bias isa Zeros) && print(io, ", bias=false")
 end

 """
@@ -216,44 +225,53 @@ struct ConvTranspose{N,M,F,A,V}
   stride::NTuple{N,Int}
   pad::NTuple{M,Int}
   dilation::NTuple{N,Int}
+  groups::Int
 end

+_channels_in(l::ConvTranspose) = size(l.weight)[end]
+_channels_out(l::ConvTranspose) = size(l.weight)[end-1]*l.groups
+
 """
-    ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation])
+    ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation, groups])

 Constructs a layer with the given weight and bias arrays.
 Accepts the same keywords as the `ConvTranspose((4,4), 3 => 7, relu)` method.
 """
 function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity;
-                      stride = 1, pad = 0, dilation = 1) where {T,N}
+                      stride = 1, pad = 0, dilation = 1, groups=1) where {T,N}
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
-  b = create_bias(w, bias, size(w, N-1))
-  return ConvTranspose(σ, w, b, stride, pad, dilation)
+  b = create_bias(w, bias, size(w, N-1) * groups)
+  return ConvTranspose(σ, w, b, stride, pad, dilation, groups)
 end

 function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                       init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
-                      weight = convfilter(k, reverse(ch), init = init), bias = true) where N
+                      groups = 1,
+                      weight = convfilter(k, reverse(ch); init, groups),
+                      bias = true,
+                      ) where N

-  ConvTranspose(weight, bias, σ; stride, pad, dilation)
+  ConvTranspose(weight, bias, σ; stride, pad, dilation, groups)
 end

 @functor ConvTranspose

 function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
-  # Calculate size of "input", from ∇conv_data()'s perspective...
-  combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
-  I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
-  C_in = size(c.weight)[end-1]
-  batch_size = size(x)[end]
-  # Create DenseConvDims() that looks like the corresponding conv()
-  return DenseConvDims((I..., C_in, batch_size), size(c.weight);
-                      stride=c.stride,
-                      padding=c.pad,
-                      dilation=c.dilation,
-  )
+    # Calculate size of "input", from ∇conv_data()'s perspective...
+    combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
+    I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
+    C_in = size(c.weight)[end-1] * c.groups
+    batch_size = size(x)[end]
+    # Create DenseConvDims() that looks like the corresponding conv()
+    w_size = size(c.weight)
+    return DenseConvDims((I..., C_in, batch_size), w_size;
+                        stride=c.stride,
+                        padding=c.pad,
+                        dilation=c.dilation,
+                        groups=c.groups,
+    )
 end

 # TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900
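
For orientation, conv_transpose_dims builds the DenseConvDims of the forward convolution whose data gradient implements the transposed convolution, so the spatial size I it computes is the layer's output size. A worked instance of the formula above, plugging in the 10×10 input, 3×3 kernel, SamePad case exercised in the tests (stride = dilation = 1, so pad = (1, 1)):

# One spatial dimension of I = (size(x) .- 1) .* stride .+ 1 .+ (k .- 1) .* dilation .- combined_pad
x_spatial, k, stride, dilation, pad = 10, 3, 1, 1, (1, 1)
I = (x_spatial - 1) * stride + 1 + (k - 1) * dilation - sum(pad)   # 9 + 1 + 2 - 2 = 10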
@@ -267,7 +285,7 @@ end

 function Base.show(io::IO, l::ConvTranspose)
   print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
-  print(io, ", ", size(l.weight, ndims(l.weight)), " => ", size(l.weight, ndims(l.weight)-1))
+  print(io, ", ", _channels_in(l), " => ", _channels_out(l))
   _print_conv_opt(io, l)
   print(io, ")")
 end
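
Taken together, the helpers above encode the two weight layouts: Conv stores its kernel as (k..., cin ÷ groups, cout) while ConvTranspose stores (k..., cout ÷ groups, cin), which is why the show methods now go through _channels_in and _channels_out. A small sketch using layer sizes that appear in this PR's tests:

using Flux

c  = Conv((3, 3), 100 => 25, groups = 5)
ct = ConvTranspose((3, 3), 2 => 4, groups = 2)

size(c.weight)    # (3, 3, 20, 25): channels in = 20 * 5 = 100, channels out = 25
size(ct.weight)   # (3, 3, 2, 2):   channels in = 2 (last dim), channels out = 2 * 2 = 4
Flux._channels_in(ct), Flux._channels_out(ct)   # (2, 4), matching the 2 => 4 pair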

test/cuda/layers.jl

Lines changed: 11 additions & 6 deletions
@@ -48,13 +48,17 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; te
   xg_cpu = gradient(x -> sum(l_cpu(x)), x_cpu)[1]
   xg_gpu = gradient(x -> sum(l_gpu(x)), x_gpu)[1]

-   # test
+  # test
   if test_cpu
     @test y_gpu ≈ y_cpu rtol=1f-3 atol=1f-3
     if isnothing(xg_cpu)
       @test isnothing(xg_gpu)
     else
-      @test Array(xg_gpu) ≈ xg_cpu rtol=1f-3 atol=1f-3
+      if layer === GroupedConvTranspose
+        @test Array(xg_gpu) ≈ xg_cpu rtol=2f-2 atol=1f-3
+      else
+        @test Array(xg_gpu) ≈ xg_cpu rtol=1f-3 atol=1f-3
+      end
     end
   end
   @test gs_gpu isa Flux.Zygote.Grads
@@ -80,6 +84,7 @@ ConvTransposeNoBias(args...) = ConvTranspose(args...; bias = false)
 CrossCorNoBias(args...) = CrossCor(args...; bias = false)
 DepthwiseConvNoBias(args...) = DepthwiseConv(args...; bias = false)
 GroupedConv(args...) = Conv(args..., groups = 5)
+GroupedConvTranspose(args...) = ConvTranspose(args..., groups = 5)

 for act in ACTIVATIONS
   r = rand(Float32, 28, 28, 1, 1)
@@ -89,16 +94,16 @@ for act in ACTIVATIONS
                  DepthwiseConv, DepthwiseConvNoBias]
   gpu_gradtest("Convolution with $act", conv_layers, r, (2,2), 1=>3, act, test_cpu = false)

-  groupedconv = [GroupedConv]
+  groupedconv = [GroupedConv, GroupedConvTranspose]
   gpu_gradtest("GroupedConvolution with $act", groupedconv, rand(Float32, 28, 28, 100, 2), (3,3), 100 => 25, act, test_cpu = true)

   batch_norm = [BatchNorm]
   gpu_gradtest("BatchNorm 1 with $act", batch_norm, rand(Float32, 28,28,3,4), 3, act, test_cpu = false) #TODO fix errors
   gpu_gradtest("BatchNorm 2 with $act", batch_norm, rand(Float32, 5,4), 5, act, test_cpu = false)
-
+
   instancenorm = [InstanceNorm]
   gpu_gradtest("InstanceNorm with $act", instancenorm, r, 1, act, test_cpu = false)
-
+
   groupnorm = [GroupNorm]
   gpu_gradtest("GroupNorm with $act", groupnorm, rand(Float32, 28,28,3,1), 3, 1, act, test_cpu = false)
 end
@@ -151,7 +156,7 @@ end
     else
       @test sum(l(ip)) ≈ 0.f0
       gs = gradient(() -> sum(l(ip)), Flux.params(l))
-      @test l.bias ∈ gs.params
+      @test l.bias ∈ gs.params
     end
   end
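
The harness change above only loosens rtol for the new grouped transposed layer; a minimal sketch of the CPU-vs-GPU gradient comparison it performs, assuming a working CUDA device and using the same layer size as the grouped test case:

using Flux, CUDA

m_cpu = ConvTranspose((3, 3), 100 => 25, groups = 5)
m_gpu = m_cpu |> gpu
x_cpu = rand(Float32, 28, 28, 100, 2)
x_gpu = x_cpu |> gpu

gx_cpu = gradient(x -> sum(m_cpu(x)), x_cpu)[1]
gx_gpu = gradient(x -> sum(m_gpu(x)), x_gpu)[1]
isapprox(Array(gx_gpu), gx_cpu; rtol = 2f-2, atol = 1f-3)   # looser rtol, as in the test above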

test/layers/conv.jl

Lines changed: 76 additions & 1 deletion
@@ -67,13 +67,60 @@ end
   @test Flux.Losses.mse(bias(ip), op) ≈ 4.f0

   @testset "Grouped Conv" begin
+    ip = rand(Float32, 28, 100, 2)
+    c = Conv((3,), 100 => 25, groups = 5)
+    @test size(c.weight) == (3, 20, 25)
+    @test size(c(ip)) == (26, 25, 2)
+
     ip = rand(Float32, 28, 28, 100, 2)
     c = Conv((3,3), 100 => 25, groups = 5)
     @test size(c.weight) == (3, 3, 20, 25)
     @test size(c(ip)) == (26, 26, 25, 2)
+
+    ip = rand(Float32, 10, 11, 12, 100, 2)
+    c = Conv((3,4,5), 100 => 25, groups = 5)
+    @test size(c.weight) == (3,4,5, 20, 25)
+    @test size(c(ip)) == (8,8,8, 25, 2)
   end
 end

+@testset "_channels_in, _channels_out" begin
+  _channels_in = Flux._channels_in
+  _channels_out = Flux._channels_out
+  @test _channels_in(Conv((3,) , 2=>4)) == 2
+  @test _channels_in(Conv((5,6,) , 2=>4)) == 2
+  @test _channels_in(Conv((1,2,3), 2=>4)) == 2
+  @test _channels_out(Conv((3,) , 2=>4)) == 4
+  @test _channels_out(Conv((5,6,) , 2=>4)) == 4
+  @test _channels_out(Conv((1,2,3), 2=>4)) == 4
+
+  @test _channels_in( ConvTranspose((3,) , 1=>4)) == 1
+  @test _channels_in( ConvTranspose((5,6,) , 2=>4)) == 2
+  @test _channels_in( ConvTranspose((1,2,3), 3=>4)) == 3
+  @test _channels_out(ConvTranspose((3,) , 2=>1)) == 1
+  @test _channels_out(ConvTranspose((5,6,) , 2=>2)) == 2
+  @test _channels_out(ConvTranspose((1,2,3), 2=>3)) == 3
+
+  @test _channels_in( ConvTranspose((6,) , 8=>4, groups=4)) == 8
+  @test _channels_in( ConvTranspose((5,6,) , 2=>4, groups=2)) == 2
+  @test _channels_in( ConvTranspose((1,2,3), 3=>6, groups=3)) == 3
+
+  @test _channels_out(ConvTranspose((1,) , 10=>15, groups=5)) == 15
+  @test _channels_out(ConvTranspose((3,2) , 10=>15, groups=5)) == 15
+  @test _channels_out(ConvTranspose((5,6,) , 2=>2, groups=2)) == 2
+
+  for Layer in [Conv, ConvTranspose]
+    for _ in 1:10
+      groups = rand(1:10)
+      kernel_size = Tuple(rand(1:5) for _ in rand(1:3))
+      cin = rand(1:5) * groups
+      cout = rand(1:5) * groups
+      @test _channels_in(Layer(kernel_size, cin=>cout; groups)) == cin
+      @test _channels_out(Layer(kernel_size, cin=>cout; groups)) == cout
+    end
+  end
+end
+
 @testset "asymmetric padding" begin
   r = ones(Float32, 28, 28, 1, 1)
   m = Conv((3, 3), 1=>1, relu; pad=(0,1,1,2))
@@ -118,14 +165,42 @@ end
   x = zeros(Float32, 5, 5, 2, 4)
   m = ConvTranspose((3,3), 2=>3)
   @test gradient(()->sum(m(x)), params(m)) isa Flux.Zygote.Grads
+
+  # test ConvTranspose supports groups argument
+  x = randn(Float32, 10, 10, 2, 3)
+  m1 = ConvTranspose((3,3), 2=>4, pad=SamePad())
+  @test size(m1.weight) == (3,3,4,2)
+  @test size(m1(x)) == (10,10,4,3)
+  m2 = ConvTranspose((3,3), 2=>4, groups=2, pad=SamePad())
+  @test size(m2.weight) == (3,3,2,2)
+  @test size(m1(x)) == size(m2(x))
+  @test gradient(()->sum(m2(x)), params(m2)) isa Flux.Zygote.Grads
+
+  x = randn(Float32, 10, 2,1)
+  m = ConvTranspose((3,), 2=>4, pad=SamePad(), groups=2)
+  @test size(m(x)) === (10,4,1)
+  @test length(m.weight) == (3)*(2*4) / 2
+
+  x = randn(Float32, 10, 11, 4,2)
+  m = ConvTranspose((3,5), 4=>4, pad=SamePad(), groups=4)
+  @test size(m(x)) === (10,11, 4,2)
+  @test length(m.weight) == (3*5)*(4*4)/4
+
+  x = randn(Float32, 10, 11, 12, 3,2)
+  m = ConvTranspose((3,5,3), 3=>6, pad=SamePad(), groups=3)
+  @test size(m(x)) === (10,11, 12, 6,2)
+  @test length(m.weight) == (3*5*3) * (3*6) / 3
+
+  @test occursin("groups=2", sprint(show, ConvTranspose((3,3), 2=>4, groups=2)))
+  @test occursin("2 => 4" , sprint(show, ConvTranspose((3,3), 2=>4, groups=2)))
 end

 @testset "CrossCor" begin
   x = rand(Float32, 28, 28, 1, 1)
   w = rand(Float32, 2,2,1,1)
   y = CrossCor(w, [0.0])

-  @test sum(w .* x[1:2, 1:2, :, :]) ≈ y(x)[1, 1, 1, 1] rtol=1e-7
+  @test sum(w .* x[1:2, 1:2, :, :]) ≈ y(x)[1, 1, 1, 1] rtol=2e-7

   r = zeros(Float32, 28, 28, 1, 5)
   m = Chain(