Commit a0b804a

Browse files
Merge pull request #1921 from FluxML/cl/dept
remove DepthwiseConv type in favor of Conv
2 parents 57beb23 + cab5f26 commit a0b804a

7 files changed: 26 additions, 75 deletions

docs/src/models/layers.md

Lines changed: 0 additions & 1 deletion

@@ -25,7 +25,6 @@ CrossCor
 SamePad
 Flux.flatten
 Flux.convfilter
-Flux.depthwiseconvfilter
 ```
 
 ## Upsampling Layers

src/layers/conv.jl

Lines changed: 18 additions & 70 deletions

@@ -128,6 +128,8 @@ julia> Flux.params(c1) |> length
 """
 function Conv(w::AbstractArray{T,N}, b = true, σ = identity;
             stride = 1, pad = 0, dilation = 1, groups = 1) where {T,N}
+
+  @assert size(w, N) % groups == 0 "Output channel dimension must be divisible by groups."
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride)
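A quick illustration (not part of the diff) of what the new assertion guards against when a grouped `Conv` is built directly from a weight array; the shapes below are assumptions following the `(spatial..., cin ÷ groups, cout)` layout:

```julia
using Flux

# Weight layout for a grouped Conv: (spatial..., cin ÷ groups, cout).
w = rand(Float32, 5, 5, 1, 6)

c = Conv(w, true, relu; groups = 3)  # ok: cout = 6 is divisible by groups = 3
# Conv(w, true, relu; groups = 4)    # would throw AssertionError: 6 % 4 != 0
```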
@@ -151,12 +153,12 @@ channels from `in` to `out`.
 
 Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
 distribution.
-
-See also: [`depthwiseconvfilter`](@ref)
 """
 function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
                     init = glorot_uniform, groups = 1) where N
   cin, cout = ch
+  @assert cin % groups == 0 "Input channel dimension must be divisible by groups."
+  @assert cout % groups == 0 "Output channel dimension must be divisible by groups."
   init(filter..., cin÷groups, cout)
 end
 
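`convfilter` now covers the role of the deleted `depthwiseconvfilter`. A short sketch of the weight shapes it produces; the sizes are worked out by hand from `init(filter..., cin÷groups, cout)` above:

```julia
using Flux

# Depthwise-style filter: one group per input channel, so cin ÷ groups = 1.
w = Flux.convfilter((5, 5), 3 => 6; groups = 3)
size(w)  # (5, 5, 1, 6)

# The default groups = 1 keeps the old dense filter shape.
size(Flux.convfilter((5, 5), 3 => 6))  # (5, 5, 3, 6)
```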
@@ -298,91 +300,37 @@ end
 
 """
     DepthwiseConv(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init])
+    DepthwiseConv(weight::AbstractArray, [bias, activation; stride, pad, dilation])
+
+Return a depthwise convolutional layer, that is a [`Conv`](@ref) layer with number of
+groups equal to the number of input channels.
 
-Depthwise convolutional layer. `filter` is a tuple of integers
-specifying the size of the convolutional kernel, while
-`in` and `out` specify the number of input and output channels.
-
-Note that `out` must be an integer multiple of `in`.
-
-Parameters are controlled by additional keywords, with defaults
-`init=glorot_uniform` and `bias=true`.
-
-See also [`Conv`](@ref) for more detailed description of keywords.
+See [`Conv`](@ref) for a description of the arguments.
 
 # Examples
+
 ```jldoctest
 julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images
 
 julia> lay = DepthwiseConv((5,5), 3 => 6, relu; bias=false)
-DepthwiseConv((5, 5), 3 => 6, relu, bias=false)  # 150 parameters
+Conv((5, 5), 3 => 6, relu, groups=3, bias=false)  # 150 parameters
 
 julia> lay(xs) |> size
 (96, 96, 6, 50)
 
-julia> DepthwiseConv((5,5), 3 => 9, stride=2, pad=2)(xs) |> size
+julia> DepthwiseConv((5, 5), 3 => 9, stride=2, pad=2)(xs) |> size
 (50, 50, 9, 50)
 ```
 """
-struct DepthwiseConv{N,M,F,A,V}
-  σ::F
-  weight::A
-  bias::V
-  stride::NTuple{N,Int}
-  pad::NTuple{M,Int}
-  dilation::NTuple{N,Int}
+function DepthwiseConv(k::NTuple{<:Any,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
+            stride = 1, pad = 0, dilation = 1, bias = true, init = glorot_uniform)
+  Conv(k, ch, σ; groups=ch.first, stride, pad, dilation, bias, init)
 end
 
-"""
-    DepthwiseConv(weight::AbstractArray, [bias, activation; stride, pad, dilation])
-
-Constructs a layer with the given weight and bias arrays.
-Accepts the same keywords as the `DepthwiseConv((4,4), 3 => 6, relu)` method.
-"""
 function DepthwiseConv(w::AbstractArray{T,N}, bias = true, σ = identity;
-                stride = 1, pad = 0, dilation = 1) where {T,N}
-  stride = expand(Val(N-2), stride)
-  dilation = expand(Val(N-2), dilation)
-  pad = calc_padding(DepthwiseConv, pad, size(w)[1:N-2], dilation, stride)
-  b = create_bias(w, bias, prod(size(w)[N-1:end]))
-  return DepthwiseConv(σ, w, b, stride, pad, dilation)
-end
-
-function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
-                init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
-                bias = true) where N
-  @assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
-  weight = depthwiseconvfilter(k, ch, init = init)
-  return DepthwiseConv(weight, bias, σ; stride, pad, dilation)
-end
-
-@functor DepthwiseConv
-
-"""
-    depthwiseconvfilter(filter::Tuple, in => out)
-
-Constructs a depthwise convolutional weight array defined by `filter` and channels
-from `in` to `out`.
-
-Accepts the keyword `init` (default: `glorot_uniform`) to control the sampling
-distribution.
-
-See also: [`convfilter`](@ref)
-"""
-depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
-                    init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])
-
-function (c::DepthwiseConv)(x)
-  σ = NNlib.fast_act(c.σ, x)
-  cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-  σ.(depthwiseconv(x, c.weight, cdims) .+ conv_reshape_bias(c))
-end
-
-function Base.show(io::IO, l::DepthwiseConv)
-  print(io, "DepthwiseConv(", size(l.weight)[1:end-2])
-  print(io, ", ", size(l.weight)[end], " => ", prod(size(l.weight)[end-1:end]))
-  _print_conv_opt(io, l)
-  print(io, ")")
+                stride = 1, pad = 0, dilation = 1) where {T,N}
+  w2 = reshape(w, size(w)[1:end-2]..., 1, :)
+  Conv(w2, bias, σ; groups = size(w)[end-1], stride, pad, dilation)
 end
 
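The net effect of this hunk: `DepthwiseConv` survives only as a pair of convenience constructors that forward to `Conv` with `groups` set to the number of input channels; the weight-array method additionally reshapes the old depthwise weight layout `(k..., out ÷ in, in)` into the grouped layout `(k..., 1, out)`. A small sketch of the resulting behaviour, with sizes taken from the doctest above:

```julia
using Flux

lay = DepthwiseConv((5, 5), 3 => 6, relu; bias = false)
lay isa Conv   # true: there is no separate DepthwiseConv type anymore
lay.groups     # 3, one group per input channel

# The same layer, spelled out explicitly:
lay2 = Conv((5, 5), 3 => 6, relu; groups = 3, bias = false)

x = rand(Float32, 100, 100, 3, 50);
size(lay(x))   # (96, 96, 6, 50), matching the doctest
```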
src/layers/show.jl

Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ _show_children(m::Maxout) = m.layers
 _show_children(p::Parallel) = (p.connection, p.layers...)
 
 for T in [
-  :Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense, :Bilinear, :Embedding,
+  :Conv, :ConvTranspose, :CrossCor, :Dense, :Bilinear, :Embedding,
   :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
 ]
   @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
src/outputsize.jl

Lines changed: 1 addition & 1 deletion

@@ -153,7 +153,7 @@ end
 
 ## fixes for layers that don't work out of the box
 
-for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims))
+for (fn, Dims) in ((:conv, DenseConvDims),)
   @eval begin
     function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims)
       fill(nil, NNlib.output_size(dims)..., NNlib.channels_out(dims), size(a)[end])
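Since depthwise layers are now plain `Conv`s, shape inference needs only the `conv` method on `Nil` arrays, so the `depthwiseconv` entry can go. A sketch using `Flux.outputsize`, with numbers matching the doctest in src/layers/conv.jl:

```julia
using Flux

m = DepthwiseConv((5, 5), 3 => 9; stride = 2, pad = 2)

# outputsize evaluates the model on Nil-filled arrays, hitting the method above.
Flux.outputsize(m, (100, 100, 3, 50))  # (50, 50, 9, 50)
```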
src/utils.jl

Lines changed: 1 addition & 1 deletion

@@ -383,7 +383,7 @@ Has the following behaviour
 
 Some caveats:
 * Not all layers will be identity mapping when used with this init. Exceptions
-  include recurrent layers, `DepthwiseConv` and normalization layers.
+  include recurrent layers and normalization layers.
 
 * Layers must have `input_size == output_size` for identity mapping to be
   possible. When this is not the case, extra dimensions of the array are padded with zeros.

test/cuda/layers.jl

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 end
 
 # TODO: These layers get into scalar indexing issues.
-const BROKEN_LAYERS = Union{DepthwiseConv}
+const BROKEN_LAYERS = Union{}
 
 const ACTIVATIONS = [identity, relu, tanh,
                      sigmoid, exp, softplus,
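Presumably depthwise layers no longer hit scalar indexing because a grouped `Conv` goes through NNlib's standard GPU `conv` path rather than the removed `DepthwiseConv` code. A hedged sketch, assuming a working CUDA setup:

```julia
using Flux, CUDA

m = DepthwiseConv((5, 5), 3 => 6) |> gpu  # now an ordinary grouped Conv
x = CUDA.rand(Float32, 28, 28, 3, 4)
size(m(x))  # (24, 24, 6, 4)
```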

test/layers/conv.jl

Lines changed: 4 additions & 0 deletions

@@ -81,6 +81,10 @@ end
     c = Conv((3,4,5), 100 => 25, groups = 5)
     @test size(c.weight) == (3,4,5, 20, 25)
     @test size(c(ip)) == (8,8,8, 25, 2)
+
+    # Test that we cannot ask for non-integer multiplication factors
+    @test_throws AssertionError Conv((2, 2), 3=>10, groups=2)
+    @test_throws AssertionError Conv((2, 2), 2=>9, groups=2)
   end
 end
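For illustration only (not part of the test file), this is roughly how the new checks surface at the REPL, given the assertion messages added in src/layers/conv.jl:

```julia
julia> using Flux

julia> Conv((2, 2), 3 => 10, groups = 2)
ERROR: AssertionError: Input channel dimension must be divisible by groups.

julia> Conv((2, 2), 2 => 9, groups = 2)
ERROR: AssertionError: Output channel dimension must be divisible by groups.
```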
