function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
              init = glorot_uniform, stride = 1, pad = 0, dilation = 1, groups = 1,
-             weight = convfilter(k, (ch[1] ÷ groups => ch[2]); init), bias = true) where N
+             weight = convfilter(k, ch; init, groups), bias = true) where N

  Conv(weight, bias, σ; stride, pad, dilation, groups)
end
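
Illustrative usage of the constructor above (not part of the diff; sizes assume the
default stride = 1, pad = 0 and glorot_uniform init):

    using Flux
    layer = Conv((3, 3), 8 => 16, relu; groups = 4)   # 4 groups, 8÷4 = 2 input channels each
    size(layer.weight)                                # (3, 3, 2, 16)
    x = randn(Float32, 32, 32, 8, 1)                  # WHCN input with 8 channels
    size(layer(x))                                    # (30, 30, 16, 1)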
@@ -152,8 +152,11 @@ distribution.
See also: [`depthwiseconvfilter`](@ref)
"""
-convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
-           init = glorot_uniform) where N = init(filter..., ch...)
+function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
+                    init = glorot_uniform, groups = 1) where N
+  cin, cout = ch
+  init(filter..., cin ÷ groups, cout)
+end
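
A quick sketch of what the grouped convfilter now returns (illustrative; assumes the
default glorot_uniform init). The third dimension holds the per-group input channels:

    w = convfilter((3, 3), 8 => 16; groups = 4)
    size(w)   # (3, 3, 2, 16): each of the 4 groups sees 8÷4 = 2 input channels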

@functor Conv

@@ -163,9 +166,12 @@ function (c::Conv)(x::AbstractArray)
  σ.(conv(x, c.weight, cdims) .+ b)
end

+channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
+channels_out(l::Conv) = size(l.weight, ndims(l.weight))
+
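
Worked example for the two new accessors (internal helpers, shown with hypothetical
sizes): the weight only stores the per-group input channels, so channels_in multiplies
back by the group count.

    layer = Conv((3, 3), 8 => 16; groups = 4)   # weight size (3, 3, 2, 16)
    channels_in(layer)    # 2 * 4 == 8
    channels_out(layer)   # 16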

function Base.show(io::IO, l::Conv)
  print(io, "Conv(", size(l.weight)[1:ndims(l.weight)-2])
- print(io, ", ", size(l.weight, ndims(l.weight)-1), " => ", size(l.weight, ndims(l.weight)))
+ print(io, ", ", channels_in(l), " => ", channels_out(l))
  _print_conv_opt(io, l)
  print(io, ")")
end
@@ -175,7 +181,10 @@ function _print_conv_opt(io::IO, l)
  all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad))
  all(==(1), l.stride) || print(io, ", stride=", _maybetuple_string(l.stride))
  all(==(1), l.dilation) || print(io, ", dilation=", _maybetuple_string(l.dilation))
- l.bias == Zeros() && print(io, ", bias=false")
+ if hasproperty(l, :groups)
+   (l.groups == 1) || print(io, ", groups=", l.groups)
+ end
+ (l.bias isa Zeros) && print(io, ", bias=false")
end
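
With the extra branch above, a non-default group count now shows up in the printed
summary. Roughly (the activation is printed by an earlier line of _print_conv_opt not
shown in this hunk):

    Conv((3, 3), 8 => 16, relu; groups = 4, bias = false)
    # displays as: Conv((3, 3), 8 => 16, relu, groups=4, bias=false)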

"""
@@ -216,28 +225,35 @@ struct ConvTranspose{N,M,F,A,V}
  stride::NTuple{N,Int}
  pad::NTuple{M,Int}
  dilation::NTuple{N,Int}
+ groups::Int
end

+channels_in(l::ConvTranspose)  = size(l.weight)[end]
+channels_out(l::ConvTranspose) = size(l.weight)[end-1]*l.groups
+

"""
-    ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation])
+    ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation, groups])

Constructs a layer with the given weight and bias arrays.
Accepts the same keywords as the `ConvTranspose((4,4), 3 => 7, relu)` method.
"""
function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity;
-                      stride = 1, pad = 0, dilation = 1) where {T,N}
+                      stride = 1, pad = 0, dilation = 1, groups = 1) where {T,N}
  stride = expand(Val(N-2), stride)
  dilation = expand(Val(N-2), dilation)
  pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
- b = create_bias(w, bias, size(w, N-1))
- return ConvTranspose(σ, w, b, stride, pad, dilation)
+ b = create_bias(w, bias, size(w, N-1) * groups)
+ return ConvTranspose(σ, w, b, stride, pad, dilation, groups)
end

function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                       init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
-                      weight = convfilter(k, reverse(ch), init = init), bias = true) where N
+                      groups = 1,
+                      weight = convfilter(k, ch[2]÷groups => ch[1], init = init),
+                      bias = true,
+                      ) where N

- ConvTranspose(weight, bias, σ; stride, pad, dilation)
+ ConvTranspose(weight, bias, σ; stride, pad, dilation, groups)
end
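
An illustrative construction with the grouped keyword constructor above. Because the
default weight is convfilter(k, ch[2]÷groups => ch[1], ...), the stored weight holds the
per-group output channels, which is why channels_out multiplies back by groups:

    layer = ConvTranspose((4, 4), 16 => 8; groups = 4)
    size(layer.weight)     # (4, 4, 2, 16): 8÷4 = 2 output channels per group, 16 inputs
    channels_in(layer)     # 16
    channels_out(layer)    # 2 * 4 == 8, matching the bias length size(w, N-1) * groups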

@functor ConvTranspose
@@ -246,13 +262,15 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
  # Calculate size of "input", from ∇conv_data()'s perspective...
  combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
  I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
- C_in = size(c.weight)[end-1]
+ C_in = size(c.weight)[end-1] * c.groups
  batch_size = size(x)[end]
  # Create DenseConvDims() that looks like the corresponding conv()
- return DenseConvDims((I..., C_in, batch_size), size(c.weight);
+ w_size = size(c.weight)
+ return DenseConvDims((I..., C_in, batch_size), w_size;
                       stride=c.stride,
                       padding=c.pad,
                       dilation=c.dilation,
+                      groups=c.groups,
  )
end
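
A worked trace of the sizes above for the hypothetical 16 => 8, groups = 4 layer applied
to a 16×16 input (stride 1, pad 0, dilation 1); the DenseConvDims describes the
equivalent forward convolution from ∇conv_data's point of view:

    # x is (16, 16, 16, 1); c.weight is (4, 4, 2, 16)
    # I    = (16 - 1)*1 + 1 + (4 - 1)*1 - 0  = 19 in each spatial dimension
    # C_in = 2 * 4                           = 8, the layer's output channel count
    # so DenseConvDims is built over a (19, 19, 8, 1) "input", with groups = 4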

function Base.show(io::IO, l::ConvTranspose)
  print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
- print(io, ", ", size(l.weight, ndims(l.weight)), " => ", size(l.weight, ndims(l.weight)-1))
+ print(io, ", ", channels_in(l), " => ", channels_out(l))
  _print_conv_opt(io, l)
  print(io, ")")
end