Skip to content

Commit 6ed69c8

Browse files
committed
allow groups in ConvTranspose
1 parent 7c04652 commit 6ed69c8

File tree

2 files changed

+45
-14
lines changed

2 files changed

+45
-14
lines changed

src/layers/conv.jl

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ end
136136

137137
function Conv(
    k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
    init = glorot_uniform,
    stride = 1,
    pad = 0,
    dilation = 1,
    groups = 1,
    weight = convfilter(k, ch; init = init, groups = groups),
    bias = true,
) where N
  # Convenience constructor: takes a kernel size `k` and a channel pair
  # `ch = in => out`. Unless the caller passes `weight`, the filter is created
  # by `convfilter`, which is handed `groups` so it can size the input-channel
  # dimension per group. Everything is then forwarded to the weight-based
  # `Conv` constructor.
  return Conv(weight, bias, σ;
              stride = stride, pad = pad, dilation = dilation, groups = groups)
end
@@ -152,8 +152,11 @@ distribution.
152152
153153
See also: [`depthwiseconvfilter`](@ref)
154154
"""
155-
convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
156-
init = glorot_uniform) where N = init(filter..., ch...)
155+
function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
          init = glorot_uniform, groups = 1) where N
  # Build a filter array of size `(filter..., cin ÷ groups, cout)` by calling
  # `init` with those dimensions, where `ch = cin => cout`.
  #
  # Throws `ArgumentError` when `groups` is not positive or does not evenly
  # divide both channel counts. Without the check, `cin ÷ groups` would
  # silently truncate and yield a filter whose shape no longer matches `ch`.
  cin, cout = ch
  groups > 0 ||
    throw(ArgumentError("groups must be a positive integer, got $groups"))
  cin % groups == 0 ||
    throw(ArgumentError("input channels ($cin) must be divisible by groups ($groups)"))
  cout % groups == 0 ||
    throw(ArgumentError("output channels ($cout) must be divisible by groups ($groups)"))
  return init(filter..., cin ÷ groups, cout)
end
157160

158161
@functor Conv
159162

@@ -163,9 +166,12 @@ function (c::Conv)(x::AbstractArray)
163166
σ.(conv(x, c.weight, cdims) .+ b)
164167
end
165168

169+
# Total input channels of a `Conv`: the second-to-last weight dimension holds
# the per-group input channel count, so it is scaled by `groups`.
channels_in(l::Conv) = size(l.weight)[end-1] * l.groups
170+
# Output channels of a `Conv`: the size of the last weight dimension.
channels_out(l::Conv) = last(size(l.weight))
171+
166172
function Base.show(io::IO, l::Conv)
  # Compact form: `Conv(kernel, in => out<options>)`. The channel counts come
  # from `channels_in` / `channels_out`, so a grouped layer shows its total
  # input channels rather than the per-group size stored in the weight.
  kernel = size(l.weight)[1:end-2]
  print(io, "Conv(", kernel, ", ", channels_in(l), " => ", channels_out(l))
  _print_conv_opt(io, l)
  print(io, ")")
end
@@ -175,7 +181,10 @@ function _print_conv_opt(io::IO, l)
175181
all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad))
176182
all(==(1), l.stride) || print(io, ", stride=", _maybetuple_string(l.stride))
177183
all(==(1), l.dilation) || print(io, ", dilation=", _maybetuple_string(l.dilation))
178-
l.bias == Zeros() && print(io, ", bias=false")
184+
if hasproperty(l, :groups)
185+
(l.groups == 1) || print(io, ", groups=", l.groups)
186+
end
187+
(l.bias isa Zeros) && print(io, ", bias=false")
179188
end
180189

181190
"""
@@ -216,28 +225,35 @@ struct ConvTranspose{N,M,F,A,V}
216225
stride::NTuple{N,Int}
217226
pad::NTuple{M,Int}
218227
dilation::NTuple{N,Int}
228+
groups::Int
219229
end
220230

231+
# Input channels of a `ConvTranspose`: the size of the last weight dimension.
channels_in(l::ConvTranspose) = size(l.weight, ndims(l.weight))
232+
# Total output channels of a `ConvTranspose`: the second-to-last weight
# dimension (the per-group output count) multiplied by `groups`.
channels_out(l::ConvTranspose) = size(l.weight, ndims(l.weight) - 1) * l.groups
233+
221234
"""
222-
ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation])
235+
ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation, groups])
223236
224237
Constructs a layer with the given weight and bias arrays.
225238
Accepts the same keywords as the `ConvTranspose((4,4), 3 => 7, relu)` method.
226239
"""
227240
function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity;
          stride = 1, pad = 0, dilation = 1, groups = 1) where {T,N}
  # The leading N-2 dimensions of `w` are the kernel size; stride and dilation
  # are expanded to one entry per kernel dimension before padding is computed.
  kernel = size(w)[1:N-2]
  strides = expand(Val(N - 2), stride)
  dilations = expand(Val(N - 2), dilation)
  padding = calc_padding(ConvTranspose, pad, kernel, dilations, strides)
  # Dimension N-1 of `w` holds the per-group output channel count, so the bias
  # is sized for that count times `groups`.
  b = create_bias(w, bias, size(w, N - 1) * groups)
  return ConvTranspose(σ, w, b, strides, padding, dilations, groups)
end
235248

236249
function ConvTranspose(
    k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
    init = glorot_uniform,
    stride = 1,
    pad = 0,
    dilation = 1,
    groups = 1,
    weight = convfilter(k, (ch[2] ÷ groups) => ch[1]; init = init),
    bias = true,
) where N
  # Convenience constructor from a kernel size and `ch = in => out`. The
  # default filter is requested with the channel pair swapped and the output
  # count already divided by `groups`, i.e. `init(k..., out ÷ groups, in)`.
  # Everything is then forwarded to the weight-based constructor.
  return ConvTranspose(weight, bias, σ;
                       stride = stride, pad = pad, dilation = dilation, groups = groups)
end
242258

243259
@functor ConvTranspose
@@ -246,13 +262,15 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
246262
# Calculate size of "input", from ∇conv_data()'s perspective...
247263
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
248264
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
249-
C_in = size(c.weight)[end-1]
265+
C_in = size(c.weight)[end-1] * c.groups
250266
batch_size = size(x)[end]
251267
# Create DenseConvDims() that looks like the corresponding conv()
252-
return DenseConvDims((I..., C_in, batch_size), size(c.weight);
268+
w_size = size(c.weight)
269+
return DenseConvDims((I..., C_in, batch_size), w_size;
253270
stride=c.stride,
254271
padding=c.pad,
255272
dilation=c.dilation,
273+
groups=c.groups,
256274
)
257275
end
258276

@@ -267,7 +285,7 @@ end
267285

268286
function Base.show(io::IO, l::ConvTranspose)
  # Compact form: `ConvTranspose(kernel, in => out<options>)`, using
  # `channels_in` / `channels_out` so that grouped layers report their total
  # output channels.
  kernel = size(l.weight)[1:end-2]
  print(io, "ConvTranspose(", kernel, ", ", channels_in(l), " => ", channels_out(l))
  _print_conv_opt(io, l)
  print(io, ")")
end

test/layers/conv.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,19 @@ end
118118
x = zeros(Float32, 5, 5, 2, 4)
119119
m = ConvTranspose((3,3), 2=>3)
120120
@test gradient(()->sum(m(x)), params(m)) isa Flux.Zygote.Grads
121+
122+
# test ConvTranspose supports groups argument
123+
x = randn(Float32, 10, 10, 2, 3)
124+
m1 = ConvTranspose((3,3), 2=>4, pad=SamePad())
125+
@test size(m1.weight) == (3,3,4,2)
126+
@test size(m1(x)) == (10,10,4,3)
127+
m2 = ConvTranspose((3,3), 2=>4, groups=2, pad=SamePad())
128+
@test size(m2.weight) == (3,3,2,2)
129+
@test size(m1(x)) == size(m2(x))
130+
@test gradient(()->sum(m2(x)), params(m2)) isa Flux.Zygote.Grads
131+
132+
@test occursin("groups=2", sprint(show, ConvTranspose((3,3), 2=>4, groups=2)))
133+
@test occursin("2 => 4" , sprint(show, ConvTranspose((3,3), 2=>4, groups=2)))
121134
end
122135

123136
@testset "CrossCor" begin

0 commit comments

Comments
 (0)