@@ -195,7 +195,7 @@ conv_dims(c::Conv, x::AbstractArray) =
195
195
ChainRulesCore. @non_differentiable conv_dims (:: Any , :: Any )
196
196
197
197
function (c:: Conv )(x:: AbstractArray )
198
- _size_check (c, x, ndims (x) - 1 => _channels_in (c) )
198
+ _conv_size_check (c, x)
199
199
σ = NNlib. fast_act (c. σ, x)
200
200
cdims = conv_dims (c, x)
201
201
xT = _match_eltype (c, x)
331
331
ChainRulesCore. @non_differentiable conv_transpose_dims (:: Any , :: Any )
332
332
333
333
function (c:: ConvTranspose )(x:: AbstractArray )
334
- _size_check (c, x, ndims (x) - 1 => _channels_in (c) )
334
+ _conv_size_check (c, x)
335
335
σ = NNlib. fast_act (c. σ, x)
336
336
cdims = conv_transpose_dims (c, x)
337
337
xT = _match_eltype (c, x)
@@ -473,7 +473,7 @@ crosscor_dims(c::CrossCor, x::AbstractArray) =
473
473
ChainRulesCore. @non_differentiable crosscor_dims (:: Any , :: Any )
474
474
475
475
function (c:: CrossCor )(x:: AbstractArray )
476
- _size_check (c, x, ndims (x) - 1 => _channels_in (c) )
476
+ _conv_size_check (c, x)
477
477
σ = NNlib. fast_act (c. σ, x)
478
478
cdims = crosscor_dims (c, x)
479
479
xT = _match_eltype (c, x)
@@ -487,6 +487,15 @@ function Base.show(io::IO, l::CrossCor)
487
487
print (io, " )" )
488
488
end
489
489
490
+ function _conv_size_check (layer, x:: AbstractArray )
491
+ ndims (x) == ndims (layer. weight) || throw (DimensionMismatch (LazyString (" layer " , layer,
492
+ " expects ndims(input) == " , ndims (layer. weight), " , but got " , summary (x))))
493
+ d = ndims (x)- 1
494
+ n = _channels_in (layer)
495
+ size (x,d) == n || throw (DimensionMismatch (LazyString (" layer " , layer,
496
+ lazy " expects size(input, $d) == $n, but got " , summary (x))))
497
+ end
498
+ ChainRulesCore. @non_differentiable _conv_size_check (:: Any , :: Any )
490
499
"""
491
500
AdaptiveMaxPool(out::NTuple)
492
501
@@ -515,6 +524,7 @@ struct AdaptiveMaxPool{S, O}
515
524
end
516
525
517
526
function (a:: AdaptiveMaxPool{S} )(x:: AbstractArray{T, S} ) where {S, T}
527
+ _pool_size_check (a, a. out, x)
518
528
insize = size (x)[1 : end - 2 ]
519
529
outsize = a. out
520
530
stride = insize .÷ outsize
@@ -556,6 +566,7 @@ struct AdaptiveMeanPool{S, O}
556
566
end
557
567
558
568
function (a:: AdaptiveMeanPool{S} )(x:: AbstractArray{T, S} ) where {S, T}
569
+ _pool_size_check (a, a. out, x)
559
570
insize = size (x)[1 : end - 2 ]
560
571
outsize = a. out
561
572
stride = insize .÷ outsize
@@ -694,6 +705,7 @@ function MaxPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
694
705
end
695
706
696
707
function (m:: MaxPool )(x)
708
+ _pool_size_check (m, m. k, x)
697
709
pdims = PoolDims (x, m. k; padding= m. pad, stride= m. stride)
698
710
return maxpool (x, pdims)
699
711
end
@@ -753,6 +765,7 @@ function MeanPool(k::NTuple{N,Integer}; pad = 0, stride = k) where N
753
765
end
754
766
755
767
function (m:: MeanPool )(x)
768
+ _pool_size_check (m, m. k, x)
756
769
pdims = PoolDims (x, m. k; padding= m. pad, stride= m. stride)
757
770
return meanpool (x, pdims)
758
771
end
@@ -763,3 +776,11 @@ function Base.show(io::IO, m::MeanPool)
763
776
m. stride == m. k || print (io, " , stride=" , _maybetuple_string (m. stride))
764
777
print (io, " )" )
765
778
end
779
+
780
+ function _pool_size_check (layer, tup:: Tuple , x:: AbstractArray )
781
+ N = length (tup) + 2
782
+ ndims (x) == N || throw (DimensionMismatch (LazyString (" layer " , layer,
783
+ " expects ndims(input) == " , N, " , but got " , summary (x))))
784
+ end
785
+ ChainRulesCore. @non_differentiable _pool_size_check (:: Any , :: Any )
786
+
0 commit comments