@@ -151,8 +151,6 @@ channels from `in` to `out`.
151
151
152
152
Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
153
153
distribution.
154
-
155
- See also: [`depthwiseconvfilter`](@ref)
156
154
"""
157
155
function convfilter (filter:: NTuple{N,Integer} , ch:: Pair{<:Integer,<:Integer} ;
158
156
init = glorot_uniform, groups = 1 ) where N
@@ -298,91 +296,37 @@ end
298
296
299
297
"""
300
298
DepthwiseConv(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init])
299
+ DepthwiseConv(weight::AbstractArray, [bias, activation; stride, pad, dilation])
300
+
301
+ Return a depthwise convolutional layer, that is a [`Conv`](@ref) layer with number of
302
+ groups equal to the number of input channels.
301
303
302
- Depthwise convolutional layer. `filter` is a tuple of integers
303
- specifying the size of the convolutional kernel, while
304
- `in` and `out` specify the number of input and output channels.
305
-
306
- Note that `out` must be an integer multiple of `in`.
307
-
308
- Parameters are controlled by additional keywords, with defaults
309
- `init=glorot_uniform` and `bias=true`.
310
-
311
- See also [`Conv`](@ref) for more detailed description of keywords.
304
+ See [`Conv`](@ref) for a description of the arguments.
312
305
313
306
# Examples
307
+
314
308
```jldoctest
315
309
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images
316
310
317
311
julia> lay = DepthwiseConv((5,5), 3 => 6, relu; bias=false)
318
- DepthwiseConv ((5, 5), 3 => 6, relu, bias=false) # 150 parameters
312
+ Conv ((5, 5), 3 => 6, relu, groups=3, bias=false) # 150 parameters
319
313
320
314
julia> lay(xs) |> size
321
315
(96, 96, 6, 50)
322
316
323
- julia> DepthwiseConv((5,5), 3 => 9, stride=2, pad=2)(xs) |> size
317
+ julia> DepthwiseConv((5, 5), 3 => 9, stride=2, pad=2)(xs) |> size
324
318
(50, 50, 9, 50)
325
319
```
326
320
"""
327
- struct DepthwiseConv{N,M,F,A,V}
328
- σ:: F
329
- weight:: A
330
- bias:: V
331
- stride:: NTuple{N,Int}
332
- pad:: NTuple{M,Int}
333
- dilation:: NTuple{N,Int}
321
+ function DepthwiseConv (k:: NTuple{<:Any,Integer} , ch:: Pair{<:Integer,<:Integer} , σ = identity;
322
+ stride = 1 , pad = 0 , dilation = 1 , bias = true , init = glorot_uniform)
323
+ Conv (k, ch, σ; groups= ch. first, stride, pad, dilation, bias, init)
334
324
end
335
325
336
- """
337
- DepthwiseConv(weight::AbstractArray, [bias, activation; stride, pad, dilation])
338
-
339
- Constructs a layer with the given weight and bias arrays.
340
- Accepts the same keywords as the `DepthwiseConv((4,4), 3 => 6, relu)` method.
341
- """
342
326
function DepthwiseConv (w:: AbstractArray{T,N} , bias = true , σ = identity;
343
- stride = 1 , pad = 0 , dilation = 1 ) where {T,N}
344
- stride = expand (Val (N- 2 ), stride)
345
- dilation = expand (Val (N- 2 ), dilation)
346
- pad = calc_padding (DepthwiseConv, pad, size (w)[1 : N- 2 ], dilation, stride)
347
- b = create_bias (w, bias, prod (size (w)[N- 1 : end ]))
348
- return DepthwiseConv (σ, w, b, stride, pad, dilation)
349
- end
350
-
351
- function DepthwiseConv (k:: NTuple{N,Integer} , ch:: Pair{<:Integer,<:Integer} , σ = identity;
352
- init = glorot_uniform, stride = 1 , pad = 0 , dilation = 1 ,
353
- bias = true ) where N
354
- @assert ch[2 ] % ch[1 ] == 0 " Output channels must be integer multiple of input channels"
355
- weight = depthwiseconvfilter (k, ch, init = init)
356
- return DepthwiseConv (weight, bias, σ; stride, pad, dilation)
357
- end
358
-
359
- @functor DepthwiseConv
360
-
361
- """
362
- depthwiseconvfilter(filter::Tuple, in => out)
363
-
364
- Constructs a depthwise convolutional weight array defined by `filter` and channels
365
- from `in` to `out`.
366
-
367
- Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
368
- distribution.
369
-
370
- See also: [`convfilter`](@ref)
371
- """
372
- depthwiseconvfilter (filter:: NTuple{N,Integer} , ch:: Pair{<:Integer,<:Integer} ;
373
- init = glorot_uniform) where N = init (filter... , div (ch[2 ], ch[1 ]), ch[1 ])
374
-
375
- function (c:: DepthwiseConv )(x)
376
- σ = NNlib. fast_act (c. σ, x)
377
- cdims = DepthwiseConvDims (x, c. weight; stride= c. stride, padding= c. pad, dilation= c. dilation)
378
- σ .(depthwiseconv (x, c. weight, cdims) .+ conv_reshape_bias (c))
379
- end
380
-
381
- function Base. show (io:: IO , l:: DepthwiseConv )
382
- print (io, " DepthwiseConv(" , size (l. weight)[1 : end - 2 ])
383
- print (io, " , " , size (l. weight)[end ], " => " , prod (size (l. weight)[end - 1 : end ]))
384
- _print_conv_opt (io, l)
385
- print (io, " )" )
327
+ stride = 1 , pad = 0 , dilation = 1 ) where {T,N}
328
+ w2 = reshape (w, size (w)[1 : end - 2 ]. .. , 1 , :)
329
+ Conv (w2, bias, σ; groups = size (w)[end - 1 ], stride, pad, dilation)
386
330
end
387
331
388
332
0 commit comments