Skip to content

Commit 305b305

Browse files
fix #484 -- broaden pooling method signatures (#485)
* broaden pooling method signatures * pooling: revert type prediction for target array * add ReverseDiff as test dependency * Update test/pooling.jl reformat whitespace in tests Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com> * Update test/pooling.jl remove unused type parameters from function signatures Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com> * Update src/impl/pooling_direct.jl remove unused type parameters from function signatures Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com> * remove unused type parameters in pooling methods --------- Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
1 parent 6824c6d commit 305b305

File tree

5 files changed

+97
-41
lines changed

5 files changed

+97
-41
lines changed

src/impl/pooling_direct.jl

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
# the inner loop operation and a few initialization parameters.
33
for name in (:max, :mean, :lpnorm)
44
@eval function $((Symbol("$(name)pool_direct!")))(
5-
y::AbstractArray{T, 5}, x::AbstractArray{T, 5},
6-
pdims::PoolDims; alpha::T=T(1), beta::T=T(0), kwargs...) where T
5+
y::AbstractArray{<:Any, 5}, x::AbstractArray{<:Any, 5},
6+
pdims::PoolDims; alpha=1, beta=0, kwargs...)
77
$((Symbol("$(name)pool_direct!")))(
88
y, x, pdims,
99
Val(kernel_size(pdims)), Val(channels_out(pdims)),
@@ -13,13 +13,13 @@ for name in (:max, :mean, :lpnorm)
1313
end
1414

1515
@eval function $((Symbol("$(name)pool_direct!")))(
16-
y::AbstractArray{T,5}, x::AbstractArray{T,5},
16+
y::AbstractArray{T,5}, x::AbstractArray{<:Any,5},
1717
pdims::PoolDims,
1818
# kernel size, channels out, padding, dilation, stride
1919
::Val{K}, ::Val{C}, ::Val{P}, ::Val{D}, ::Val{S};
20-
alpha::T=T(1), beta::T=T(0), kwargs...
20+
alpha=1, beta=0, kwargs...
2121
) where {T, K, C, P, D, S}
22-
@assert beta == T(0) "beta not supported yet"
22+
@assert iszero(beta) "beta not supported yet"
2323
check_dims(size(x), size(y), pdims)
2424

2525
width, height, depth = input_size(pdims)
@@ -36,10 +36,21 @@ for name in (:max, :mean, :lpnorm)
3636
@inline project(idx, stride, pad) = (idx - 1) * stride - pad + 1
3737

3838
# If we're doing mean pooling, we represent division by kernel size by rolling it
39-
# into the `alpha` multiplier.
40-
if $(name == :mean)
41-
alpha = alpha / prod(K)
39+
# into the `alpha` multiplier.
40+
# The type might change here, that's why we prepend the underscore
41+
# (does it make a difference, though?)
42+
_alpha = if $(name == :mean)
43+
T(alpha / prod(K))
44+
else
45+
T(alpha)
4246
end
47+
# _beta = T(beta)
48+
49+
# A quick note on the array element types `T` (eltype of `y`) and `R` (eltype of `x`):
50+
# Ideally, `T == R`, but in some edge-cases, this might not be the case
51+
# (e.g. with `ReverseDiff.TrackedArray`, see issue #484).
52+
# If the types differ, we will initialize variables (like `_alpha` above) with the
53+
# target eltype `T`.
4354

4455
p = if $(name != :lpnorm) 0 else
4556
!haskey(kwargs, :p) && error("lpnormpool needs keyword argument `p`")
@@ -94,7 +105,7 @@ for name in (:max, :mean, :lpnorm)
94105
# for lpnormpool, y = (∑ᵢ xᵢ^p)^(1 / p)
95106
m = $(name == :lpnorm) ? m^(T(1) / p) : m
96107

97-
y[w, h, d, c, batch_idx] = alpha * m # + beta * y[w, h, d, c, batch_idx]
108+
y[w, h, d, c, batch_idx] = _alpha * m # + _beta * y[w, h, d, c, batch_idx]
98109
end
99110
end
100111
end
@@ -148,7 +159,7 @@ for name in (:max, :mean, :lpnorm)
148159
end
149160
end
150161
$(name == :lpnorm) && (m = m^(T(1) / p))
151-
y[w, h, d, c, batch_idx] = alpha * m # + beta * y[w, h, d, c, batch_idx]
162+
y[w, h, d, c, batch_idx] = _alpha * m # + _beta * y[w, h, d, c, batch_idx]
152163
end
153164
end
154165
end
@@ -159,9 +170,9 @@ for name in (:max, :mean, :lpnorm)
159170
end
160171

161172
@eval function $((Symbol("$(name)pool_direct!")))(
162-
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
163-
y::AbstractArray{T,5}, x::AbstractArray{T,5},
164-
pdims::PoolDims; kwargs...) where T
173+
dx::AbstractArray{<:Any,5}, dy::AbstractArray{<:Any,5},
174+
y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5},
175+
pdims::PoolDims; kwargs...)
165176
$((Symbol("$(name)pool_direct!")))(
166177
dx, dy, y, x, pdims, Val(kernel_size(pdims)); kwargs...)
167178
return dx
@@ -170,10 +181,10 @@ for name in (:max, :mean, :lpnorm)
170181
# Same story for gradients, and although this is very similar to the forward pass,
171182
# it's unfortunately different enough that I think we need a separate function. :(
172183
@eval function $((Symbol("$(name)pool_direct!")))(
173-
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
174-
y::AbstractArray{T,5}, x::AbstractArray{T,5},
184+
dx::AbstractArray{T,5}, dy::AbstractArray{<:Any,5},
185+
y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5},
175186
pdims::PoolDims, ::Val{K}; # == kernel_size(pdims)
176-
alpha::T=T(1), beta::T=T(0), kwargs...) where {T, K}
187+
alpha=1, beta=0, kwargs...) where {T, K}
177188
check_dims(size(x), size(dy), pdims)
178189

179190
width, height, depth = input_size(pdims)
@@ -183,6 +194,10 @@ for name in (:max, :mean, :lpnorm)
183194
dil_w, dil_h, dil_d = dilation(pdims)
184195
stride_w, stride_h, stride_d = stride(pdims)
185196

197+
# Concerning array eltypes `DX, DY, X, Y`, we want to handle them like above, i.e.,
198+
# initialize everything with the left-hand-side type (target type).
199+
# Of course, ideally the types are all the same anyways.
200+
186201
# We use calc_padding_regions to split ourselves up into separate regions that
187202
# may or may not need to worry about padding:
188203
padded_regions, central_region = calc_padding_regions(pdims)
@@ -191,9 +206,11 @@ for name in (:max, :mean, :lpnorm)
191206
@inline project(idx, stride, pad) = (idx - 1) * stride - pad + 1
192207

193208
# If we're doing mean pooling, we represent division by kernel size by rolling
194-
# it into the `alpha` multiplier.
195-
if $(name == :mean)
196-
alpha = alpha / prod(K)
209+
# it into the `_alpha` multiplier.
210+
_alpha = if $(name == :mean)
211+
T(alpha / prod(K))
212+
else
213+
T(alpha)
197214
end
198215

199216
p = if $(name != :lpnorm) 0 else
@@ -236,15 +253,15 @@ for name in (:max, :mean, :lpnorm)
236253
# Uncomment line below if using with non-precise output (e.g. by NNPACK)
237254
# if abs(y_idx - x[x_idxs...]) < 1e-5 && !maxpool_already_chose
238255
if y_idx ≈ x[input_kw, input_kh, input_kd, c, batch_idx]
239-
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * alpha #+ beta * dx[x_idxs...]
256+
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha #+ _beta * dx[x_idxs...]
240257
maxpool_already_chose = true
241258
# Maxpooling does not support `beta` right now. :(
242259
# else
243260
# dx[x_idxs...] = T(0) + beta*dx[x_idxs...]
244261
end
245262
elseif $(name == :mean)
246263
# Neither does meanpool :(
247-
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * alpha
264+
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha
248265
elseif $(name == :lpnorm)
249266
# y = (∑ᵢ xᵢ^p)^(1 / p), ∂y/∂xᵢ = xᵢ^(p-1) × y^(1-p)
250267
grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
@@ -302,13 +319,13 @@ for name in (:max, :mean, :lpnorm)
302319
# Uncomment line below if using with non-precise output
303320
# if abs(y_idx - x[x_idxs...]) < 1e-5 && !maxpool_already_chose
304321
if y_idx ≈ x[input_kw, input_kh, input_kd, c, batch_idx]
305-
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * alpha #+ beta * dx[x_idxs...]
322+
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha #+ _beta * dx[x_idxs...]
306323
maxpool_already_chose = true
307324
# else
308325
# dx[x_idxs...] = T(0) + beta*dx[x_idxs...]
309326
end
310327
elseif $(name == :mean)
311-
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * alpha #+ beta * dx[x_idxs...]
328+
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * _alpha #+ _beta * dx[x_idxs...]
312329
elseif $(name == :lpnorm)
313330
grad = x[input_kw, input_kh, input_kd, c, batch_idx]^(p-1) * y_idx^(1-p)
314331
dx[input_kw, input_kh, input_kd, c, batch_idx] += dy_idx * grad

src/pooling.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ for (front_name, backend) in (
3636
# We only define 3d pooling primitives, we reshape lower down to get 1d and 2d pooling
3737
@eval begin
3838
function $(Symbol("$(front_name)!"))(
39-
y::AbstractArray{T,5}, x::AbstractArray{T,5},
40-
pdims::PoolDims; kwargs...) where {T}
39+
y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5},
40+
pdims::PoolDims; kwargs...)
4141
$(Symbol("$(front_name)_$(backend)!"))(y, x, pdims; kwargs...)
4242
end
4343
end
@@ -51,9 +51,9 @@ for (front_name, backend) in (
5151
)
5252
@eval begin
5353
function $(Symbol("$(front_name)!"))(
54-
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
55-
y::AbstractArray{T,5}, x::AbstractArray{T,5},
56-
pdims::PoolDims; kwargs...) where {T}
54+
dx::AbstractArray{<:Any,5}, dy::AbstractArray{<:Any,5},
55+
y::AbstractArray{<:Any,5}, x::AbstractArray{<:Any,5},
56+
pdims::PoolDims; kwargs...)
5757
$(Symbol("$(front_name)_$(backend)!"))(dx, dy, y, x, pdims; kwargs...)
5858
end
5959
end
@@ -68,8 +68,8 @@ for front_name in (:maxpool, :meanpool, :lpnormpool)
6868
for N in (3, 4)
6969
@eval begin
7070
function $(Symbol("$(front_name)$(backend)!"))(
71-
y::AbstractArray{T,$N}, x::AbstractArray{T,$N},
72-
pdims::PoolDims; kwargs...) where {T}
71+
y::AbstractArray{<:Any,$N}, x::AbstractArray{<:Any,$N},
72+
pdims::PoolDims; kwargs...)
7373
$(Symbol("$(front_name)$(backend)!"))(
7474
insert_singleton_spatial_dimension(y, $(5 - N)),
7575
insert_singleton_spatial_dimension(x, $(5 - N)),
@@ -84,9 +84,9 @@ for front_name in (:maxpool, :meanpool, :lpnormpool)
8484

8585
# backprops too
8686
function $(Symbol("$(front_name)$(backend)!"))(
87-
dx::AbstractArray{T,$N}, dy::AbstractArray{T,$N},
88-
y::AbstractArray{T,$N}, x::AbstractArray{T,$N},
89-
pdims::PoolDims; kwargs...) where {T}
87+
dx::AbstractArray{<:Any,$N}, dy::AbstractArray{<:Any,$N},
88+
y::AbstractArray{<:Any,$N}, x::AbstractArray{<:Any,$N},
89+
pdims::PoolDims; kwargs...)
9090
$(Symbol("$(front_name)$(backend)!"))(
9191
insert_singleton_spatial_dimension(dx, $(5 - N)),
9292
insert_singleton_spatial_dimension(dy, $(5 - N)),
@@ -112,20 +112,20 @@ for backend in (Symbol(), :_direct, :_nnpack)
112112
for name in (:maxpool, :meanpool, :lpnormpool)
113113
@eval begin
114114
function $(Symbol("$(name)$(backend)"))(
115-
x::AbstractArray{xT,N},
116-
pdims::PoolDims; kwargs...) where {xT, N}
115+
x::AbstractArray{<:Any,N},
116+
pdims::PoolDims; kwargs...) where {N}
117117
y = similar(x, output_size(pdims)..., channels_out(pdims), size(x, N))
118-
fill!(y, xT(0))
118+
fill!(y, 0)
119119
return $(Symbol("$(name)$(backend)!"))(y, x, pdims; kwargs...)
120120
end
121121

122122
# Backprops too
123123
function $(Symbol("$(name)$(backend)"))(
124-
dy::AbstractArray{T,N}, y::AbstractArray{T,N},
125-
x::AbstractArray{T,N}, pdims::PoolDims;
126-
kwargs...) where {T, N}
124+
dy::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N},
125+
x::AbstractArray{<:Any,N}, pdims::PoolDims;
126+
kwargs...) where {N}
127127
dx = similar(x, input_size(pdims)..., channels_in(pdims), size(dy, N))
128-
fill!(dx, T(0))
128+
fill!(dx, 0)
129129
return $(Symbol("$(name)$(backend)!"))(dx, dy, y, x, pdims; kwargs...)
130130
end
131131
end

test/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1212
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
14+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1415
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1516
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1718
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
1819
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
19-
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
20+
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

test/pooling.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,42 @@ maxpool_answer_nature = Dict(
905905
-0.25, -0.5, -0.25], (3, 3, 1, 1))
906906
@test all(pool .== valid)
907907

908+
# issue #484
909+
# Description: some in-place pooling functions only accepted arrays with the same eltype.
910+
# The strict method signatures were based on an assumption about the return type of `similar`.
911+
# For ReverseDiff, this caused problems, e.g. with taking derivatives of pooling
912+
# operations.
913+
# Now, when explicitly calling an in-place pooling function, a different `yT` is allowed.
914+
for xT in (Int32, Int64, Float16, Float32, Float64, BigFloat)
915+
for (xsz, psz) in ( # test a few different data and kernel sizes
916+
((1,1), (1,1)),
917+
((1,2), (1,1)), ((1,2), (1,2)),
918+
((2,1), (1,1)), ((2,1), (2,1)),
919+
((2,2), (1,1)), ((2,2), (1,2)), ((2,2), (2,1)),
920+
)
921+
x = ones(xT, xsz..., 1, 1)
922+
pdims = PoolDims(x, psz)
923+
for yT in (Float16, Float32, Float64, BigFloat)
924+
# `yT` is the target eltype and we do not test integer types here
925+
# because those cannot always store the pooling results.
926+
y = similar(x, yT, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, 4))
927+
@test maxpool!(y, x, pdims) isa Array{yT}
928+
@test meanpool!(y, x, pdims) isa Array{yT}
929+
@test lpnormpool!(y, x, pdims; p=2) isa Array{yT}
930+
@test lpnormpool!(y, x, pdims; p=1.0) isa Array{yT}
931+
end
932+
end
933+
end
934+
935+
# This is how to test #484 with ReverseDiff:
936+
x = reshape(Float32[ 1 2; 3 4 ], (2,2,1,1))
937+
@test only(maxpool(x, (2,2))) == 4
938+
# define typemin, because of https://github.com/JuliaDiff/ReverseDiff.jl/issues/225
939+
Base.typemin(tr::Type{<:T}) where{V, T<:RD.TrackedReal{V, <:Any, <:Any}} = T(typemin(V))
940+
@test RD.gradient(_x -> only(maxpool(_x,(2,2))), x)[:,:,1,1] == [0 0; 0 1]
941+
@test only(meanpool(x, (2,2))) == 2.5
942+
@test all(==(0.25), RD.gradient(_x -> only(meanpool(_x,(2,2))), x))
943+
908944
# if NNlib.is_nnpack_available()
909945
# if NNlib.nnpack_supported_operation(pdims1)
910946
# @test NNlib.maxpool_nnpack(x, pdims1) isa Array{Float32, 4}

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ using StableRNGs
99
using Documenter
1010
using Adapt
1111
using KernelAbstractions
12+
import ReverseDiff as RD # used in `pooling.jl`
13+
1214
DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true)
1315

1416
ENV["NNLIB_TEST_CUDA"] = true # uncomment to run CUDA tests

0 commit comments

Comments
 (0)