Skip to content

Commit cab5f26

Browse files
assertion num channels compatible with groups
1 parent 450cb2e commit cab5f26

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

src/layers/conv.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ julia> Flux.params(c1) |> length
128128
"""
129129
function Conv(w::AbstractArray{T,N}, b = true, σ = identity;
130130
stride = 1, pad = 0, dilation = 1, groups = 1) where {T,N}
131+
132+
@assert size(w, N) % groups == 0 "Output channel dimension must be divisible by groups."
131133
stride = expand(Val(N-2), stride)
132134
dilation = expand(Val(N-2), dilation)
133135
pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride)
@@ -155,6 +157,8 @@ distribution.
155157
function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
156158
init = glorot_uniform, groups = 1) where N
157159
cin, cout = ch
160+
@assert cin % groups == 0 "Input channel dimension must be divisible by groups."
161+
@assert cout % groups == 0 "Output channel dimension must be divisible by groups."
158162
init(filter..., cin÷groups, cout)
159163
end
160164

test/layers/conv.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ end
8181
c = Conv((3,4,5), 100 => 25, groups = 5)
8282
@test size(c.weight) == (3,4,5, 20, 25)
8383
@test size(c(ip)) == (8,8,8, 25, 2)
84+
85+
# Test that we cannot ask for non-integer multiplication factors
86+
@test_throws AssertionError Conv((2, 2), 3=>10, groups=2)
87+
@test_throws AssertionError Conv((2, 2), 2=>9, groups=2)
8488
end
8589
end
8690

0 commit comments

Comments
 (0)