Skip to content

Commit c4a0ee4

Browse files
authored
Improve errors for conv layers (#2404)
* better size check for conv layers * similar for pooling layers * change to DimensionMismatch
1 parent edc1d8c commit c4a0ee4

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

src/layers/basic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ function _size_check(layer, x::AbstractArray, (d, n)::Pair)
193193
d > 0 || throw(DimensionMismatch(string("layer ", layer,
194194
" expects ndims(input) > ", ndims(x)-d, ", but got ", summary(x))))
195195
size(x, d) == n || throw(DimensionMismatch(string("layer ", layer,
196-
" expects size(input, $d) == $n, but got ", summary(x))))
196+
lazy" expects size(input, $d) == $n, but got ", summary(x))))
197197
end
198198
ChainRulesCore.@non_differentiable _size_check(::Any...)
199199

src/layers/conv.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ conv_dims(c::Conv, x::AbstractArray) =
195195
ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
196196

197197
function (c::Conv)(x::AbstractArray)
198-
_size_check(c, x, ndims(x)-1 => _channels_in(c))
198+
_conv_size_check(c, x)
199199
σ = NNlib.fast_act(c.σ, x)
200200
cdims = conv_dims(c, x)
201201
xT = _match_eltype(c, x)
@@ -331,7 +331,7 @@ end
331331
ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
332332

333333
function (c::ConvTranspose)(x::AbstractArray)
334-
_size_check(c, x, ndims(x)-1 => _channels_in(c))
334+
_conv_size_check(c, x)
335335
σ = NNlib.fast_act(c.σ, x)
336336
cdims = conv_transpose_dims(c, x)
337337
xT = _match_eltype(c, x)
@@ -473,7 +473,7 @@ crosscor_dims(c::CrossCor, x::AbstractArray) =
473473
ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
474474

475475
function (c::CrossCor)(x::AbstractArray)
476-
_size_check(c, x, ndims(x)-1 => _channels_in(c))
476+
_conv_size_check(c, x)
477477
σ = NNlib.fast_act(c.σ, x)
478478
cdims = crosscor_dims(c, x)
479479
xT = _match_eltype(c, x)
@@ -487,6 +487,15 @@ function Base.show(io::IO, l::CrossCor)
487487
print(io, ")")
488488
end
489489

490+
function _conv_size_check(layer, x::AbstractArray)
491+
ndims(x) == ndims(layer.weight) || throw(DimensionMismatch(LazyString("layer ", layer,
492+
" expects ndims(input) == ", ndims(layer.weight), ", but got ", summary(x))))
493+
d = ndims(x)-1
494+
n = _channels_in(layer)
495+
size(x,d) == n || throw(DimensionMismatch(LazyString("layer ", layer,
496+
lazy" expects size(input, $d) == $n, but got ", summary(x))))
497+
end
498+
ChainRulesCore.@non_differentiable _conv_size_check(::Any, ::Any)
490499
"""
491500
AdaptiveMaxPool(out::NTuple)
492501
@@ -515,6 +524,7 @@ struct AdaptiveMaxPool{S, O}
515524
end
516525

517526
function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}) where {S, T}
527+
_pool_size_check(a, a.out, x)
518528
insize = size(x)[1:end-2]
519529
outsize = a.out
520530
stride = insize .÷ outsize
@@ -556,6 +566,7 @@ struct AdaptiveMeanPool{S, O}
556566
end
557567

558568
function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}) where {S, T}
569+
_pool_size_check(a, a.out, x)
559570
insize = size(x)[1:end-2]
560571
outsize = a.out
561572
stride = insize .÷ outsize
@@ -694,6 +705,7 @@ function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
694705
end
695706

696707
function (m::MaxPool)(x)
708+
_pool_size_check(m, m.k, x)
697709
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
698710
return maxpool(x, pdims)
699711
end
@@ -753,6 +765,7 @@ function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
753765
end
754766

755767
function (m::MeanPool)(x)
768+
_pool_size_check(m, m.k, x)
756769
pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride)
757770
return meanpool(x, pdims)
758771
end
@@ -763,3 +776,11 @@ function Base.show(io::IO, m::MeanPool)
763776
m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride))
764777
print(io, ")")
765778
end
779+
780+
function _pool_size_check(layer, tup::Tuple, x::AbstractArray)
781+
N = length(tup) + 2
782+
ndims(x) == N || throw(DimensionMismatch(LazyString("layer ", layer,
783+
" expects ndims(input) == ", N, ", but got ", summary(x))))
784+
end
785+
# `_pool_size_check` is called with three arguments (layer, tup, x), so the rule must
# cover that arity; the varargs form matches any number of arguments.
ChainRulesCore.@non_differentiable _pool_size_check(::Any...)
786+

0 commit comments

Comments
 (0)