diff --git a/src/upsample.jl b/src/upsample.jl index 5126441a2..a67c156aa 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -64,235 +64,166 @@ function ChainRulesCore.rrule(::typeof(upsample_nearest), x::AbstractArray, s::T return Ω, upsample_nearest_pullback end -""" - upsample_bilinear(x::AbstractArray{<:Number,4}, k::NTuple{2,Int}) - -Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `k`, -using bilinear interpolation. - -The size of the output is equal to -`(k[1]*S1, k[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`. - -The interpolation grid is identical to the one used by `imresize` from `Images.jl`. - -Currently only 2d upsampling is supported. -""" -function upsample_bilinear(x::AbstractArray{T,4}, k::NTuple{2,Int}) where T - # This function is gpu friendly - - imgsize = size(x) - newsize = get_newsize(imgsize, k) - - # Get linear interpolation lower- and upper index, and weights - ilow1, ihigh1, wdiff1 = get_inds_and_ws(x, imgsize[1], newsize[1], 1) - ilow2, ihigh2, wdiff2 = get_inds_and_ws(x, imgsize[2], newsize[2], 2) - - # Adjust the upper interpolation indices of the second dimension - ihigh2_r = adjoint_of_idx(ilow2)[ihigh2] - - @inbounds y = @view(x[ilow1,ilow2,:,:]) .* (1 .- wdiff1) .+ @view(x[ihigh1,ilow2,:,:]) .* wdiff1 - @inbounds y .= y .* (1 .- wdiff2) .+ y[:,ihigh2_r,:,:] .* wdiff2 - # @inbounds y = y .* (1 .- wdiff2) .+ @view(y[:,ihigh2_r,:,:]) .* wdiff2 # equivalent to line above - return y -end - -function get_inds_and_ws(x::T, n::Int, m::Int, dim::Int) where T <: AbstractArray - # Creates interpolation grid for resampling. - # Creates the same grid as used in Image.jl `imresize`. - step = n // m - offset = (n + 1)//2 - step//2 - step * (m//2 - 1) - xq = clamp.(range(offset, step=step, length=m), 1, n) - - # Creates interpolation lower and upper indices, and broadcastable weights - ilow = floor.(Int, xq) - ihigh = ceil.(Int, xq) - sizew = ntuple(i-> i == dim ? length(xq) : 1, ndims(x)) - wdiff = convert(T, reshape(xq .- ilow, sizew)) # wdiff possibly lives on gpu - return ilow, ihigh, wdiff +# utility function +@inline function compute_source_index_and_lambda( + ratio, # 0 < ratio < 1 + output_index, + input_size, + output_size +) + real_input_index = ratio*output_index + input_index0 = floor(Int, real_input_index) # typecast to int was here in C++ + offset = (input_index0 < input_size - 1) ? 1 : 0 + input_index1 = input_index0 + offset + lambda1 = real_input_index - input_index0 + lambda0 = 1 - lambda1 + return input_index0, input_index1, lambda0, lambda1 end """ - adjoint_of_idx(idx::Vector{<:Integer}) - -# Arguments -- `idx`: a vector of indices from which you want the adjoint. + upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}) + upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}) -# Outputs --`idx_adjoint`: index that inverses the operation `x[idx]`. +Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `scale`, +using bilinear interpolation. As an alternative to using `scale`, the resulting image `size` +can be directly specified with a keyword argument. -# Explanation -Determines the adjoint of the vector of indices `idx`, based on the following assumptions: -* `idx[1] == 1` -* `all(d in [0,1] for d in diff(idx))` -The adjoint of `idx` can be seen as an inverse operation such that: +The size of the output is equal to +`(scale[1]*S1, scale[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`. 
+Examples: ```julia -x = [1, 2, 3, 4, 5] -idx = [1, 2, 2, 3, 4, 4, 5] -idx_adjoint = adjoint_of_idx(idx) -@assert x[idx][idx_adjoint] == x +upsample_bilinear(x, (2, pi)) # real scaling factors are allowed +upsample_bilinear(x; size=(64,64)) # specify ouput size ``` -The above holds as long as `idx` contains every index in `x`. """ -function adjoint_of_idx(idx::Vector{Int}) - d = trues(length(idx)) - d[2:end] .= diff(idx) - idx_adjoint = findall(d) - return idx_adjoint -end - -function get_newsize(sz, k) - return ntuple(i -> i <= length(k) ? sz[i]*k[i] : sz[i], length(sz)) +function upsample_bilinear(x::AbstractArray{<:Any,4}, scale::NTuple{2,Real}) + outsize = ntuple(i -> floor(Int, scale[i] * Base.size(x, i)), 2) + return upsample_bilinear(x; size=outsize) end +upsample_bilinear(x, scale::Real) = upsample_bilinear(x, (scale,scale)) -""" - ∇upsample_bilinear(Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int}) - -# Arguments -- `Δ`: array that has been upsampled using the upsample factors in `k` - -# Outputs -- `dx`: downsampled version of `Δ` - -# Explanation - -Custom adjoint for [`upsample_bilinear`](@ref). -The adjoint of upsampling is a downsampling operation, which -in this implementation is performed using `NNlib.conv` in combination with a downsampling kernel based on the -upsampling factors. Because of the zero-padding during convolution, the values at the boundary are polluted by edge-effects, -which have been corrected for manually. -""" -function ∇upsample_bilinear(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) - # This function is gpu friendly - - # Be more efficient on some corner cases - if size(Δ, 1) == k[1] - Δ = sum(Δ, dims=1) - k = (1, k[2]) - end - if size(Δ, 2) == k[2] - Δ = sum(Δ, dims=2) - k = (k[1], 1) +function upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}) where T + w,h,c,n = Base.size(x) + if (w,h) == size + return x end - if (size(Δ, 1) == 1) && (size(Δ, 2) == 1) - dx = Δ - return dx - end - - n_chan, n_batch = size(Δ, 3), size(Δ, 4) - - kern1 = get_downsamplekernel(Δ, k[1]) - kern2 = get_downsamplekernel(Δ, k[2]) - kern = kern1 * kern2' - - pad = (floor(Int, k[1]//2), floor(Int, k[2]//2)) - stride = k - - weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan)) - weight .= 0 - for i in 1:n_chan - weight[:,:,i,i] .= kern - end - # weight = cat(fill(kern, n_chan)..., dims=(3,4)) # slow - dx = conv(Δ, weight, pad=pad, stride=stride) + y = similar(x, T, size..., c, n) + return upsample_bilinear_whcn!(y, x) +end - # Still have to fix edge effects due to zero-padding of convolution, - # TODO: Could be circumvented by having padding that just extrapolates the value at the first/last index - # nextras = tuple((Int.(floor(factor//2)) for factor in k)...) 
- nextras = (floor(Int, k[1]//2), floor(Int, k[2]//2)) +function upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}) where T<:Integer + y = float.(x) + res = upsample_bilinear(y; size=size) + return round.(T, res) +end - # First dimension edge-effect correction - if nextras[1] > 0 - kern1 = kern[1:nextras[1],:] - pad1 = (0, pad[2]) - stride1 = (1, stride[2]) - weight1 = similar(Δ, eltype(Δ), (size(kern1)..., n_chan, n_chan)) - weight1 .= 0 - for i in 1:n_chan - weight1[:,:,i,i] .= kern1 +# this is the core function which works on arrays of arbitrary size +# the implementation is a translation of https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp +# which implements open-cv style linear interpolation / upsampling +# for simplicity, corners are aligned and all logic for other behaviour has been stripped +# - whcn because there is also a cwhn implementation +# - the function is parallelized using @threads +# - RGB types could be supported via reinterpreting +# - integer types need to be converted to Float and back +# - rationals work, but are slow +function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T + size(input)[3:4] == size(output)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))") + in_w, in_h, channels, batches = size(input) + # treat batch and channel dimension as one for better parallelization granularity + channels *= batches + out_w, out_h, _, _ = size(output) + output_slice_size = out_h * out_w + + # T() and // so that we can handle rationals (super slow) + width_scale = T((in_w - 1) // (out_w - 1)) + height_scale = T((in_h - 1) // (out_h - 1)) + + @inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1 + + @inbounds Threads.@threads for c in 0:channels-1 + for oh in 0:out_h-1 + ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h) + for ow in 0:out_w-1 + iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w) + output_offset = c * output_slice_size + oh * out_w + ow + 1 + output[output_offset] = + (h0lambda * w0lambda * input[idx(c, ih0, iw0)] + # h0 * w0 * i00 + h0lambda * w1lambda * input[idx(c, ih0, iw1)] + # h0 * w1 * i01 + h1lambda * w0lambda * input[idx(c, ih1, iw0)] + # h1 * w0 * i10 + h1lambda * w1lambda * input[idx(c, ih1, iw1)]) # h1 * w1 * i11 + end end - # weight1 = cat(fill(kern1, n_chan)..., dims=(3,4)) # slow - dx[[1],:,:,:] .+= conv(Δ[1:nextras[1],:,:,:], weight1, pad=pad1, stride=stride1) - weight1 .= weight1[end:-1:1,:,:,:] - dx[[end],:,:,:] .+= conv(Δ[end-nextras[1]+1:end,:,:,:], weight1, pad=pad1, stride=stride1) - - ## Conv with views is not dispatched to CUDA.conv - # dx[[1],:,:,:] .+= conv(@view(Δ[1:nextras[1],:,:,:]), weight1, pad=pad1, stride=stride1) - # weight1 .= @view(weight1[end:-1:1,:,:,:]) - # dx[[end],:,:,:] .+= conv(@view(Δ[end-nextras[1]+1:end,:,:,:]), weight1, pad=pad1, stride=stride1) end + return output +end - # Second dimension edge-effect correction - if nextras[2] > 0 - kern2 = kern[:,1:nextras[2]] - pad2 = (pad[1], 0) - stride2 = (stride[1], 1) - weight2 = similar(Δ, eltype(Δ), (size(kern2)..., n_chan, n_chan)) - weight2 .= 0 - for i in 1:n_chan - weight2[:,:,i,i] .= kern2 - end - # weight2 = cat(fill(kern2, n_chan)..., dims=(3,4)) # slow - - yy = conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2) - dx[:,[1],:,:] .+= conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, 
stride=stride2) - weight2 .= weight2[:,end:-1:1,:,:] - dx[:,[end],:,:] .+= conv(Δ[:,end-nextras[2]+1:end,:,:], weight2, pad=pad2, stride=stride2) +""" + ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) where T - ## Conv with views is not dispatched to CUDA.conv - # yy = conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2) - # dx[:,[1],:,:] .+= conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2) - # weight2 .= @view(weight2[:,end:-1:1,:,:]) - # dx[:,[end],:,:] .+= conv(@view(Δ[:,end-nextras[2]+1:end,:,:]), weight2, pad=pad2, stride=stride2) - end +# Arguments +- `Δ`: Incoming gradient array, backpropagated from downstream layers +- `size`: Lateral (W,H) size of the image upsampled in the first place - ## Finally fix four corners if needed - n1, n2 = nextras - if (n1 > 0) & (n2 > 0) - dx[1,1,:,:] .+= sum(kern[1:n1,1:n2] .* @view(Δ[1:n1,1:n2,:,:]), dims=(1,2))[1,1,:,:] - dx[1,end,:,:] .+= sum(kern[1:n1,end-n2+1:end] .* @view(Δ[1:n1,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:] - dx[end,end,:,:] .+= sum(kern[end-n1+1:end,end-n2+1:end] .* @view(Δ[end-n1+1:end,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:] - dx[end,1,:,:] .+= sum(kern[end-n1+1:end,1:n2] .* @view(Δ[end-n1+1:end,1:n2,:,:]), dims=(1,2))[1,1,:,:] +# Outputs +- `dx`: Downsampled version of `Δ` +""" +function ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) where T + w, h, c, n = Base.size(Δ) + out_w, out_h = size + if (w,h) == (out_w, out_h) + return Δ end - - return dx + dx = zero(similar(Δ, T, out_w, out_h, c, n)) + return ∇upsample_bilinear_whcn!(dx, Δ) end -# `n` upsample factor for which a downsample kernel will be determined. -# Δ is given in case of necessity of gpu conversion -function get_downsamplekernel(Δ, n::Int) - step = 1//n - if n % 2 == 0 - start = step//2 - upward = collect(start:step:1//1) - kernel = [upward; reverse(upward)] - else - start = step - upward = collect(start:step:1//1) - kernel = [upward; reverse(upward[1:end-1])] +function ∇upsample_bilinear_whcn!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T + size(dx)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. 
Got input $(size(dx)) and output $(size(Δ))")
+    in_w, in_h, channels, batches = size(dx)
+
+    # treat batch and channel dimension as one for better parallelization granularity
+    channels *= batches
+    out_w, out_h, _, _ = size(Δ)
+    output_slice_size = out_h * out_w
+
+    width_scale = T((in_w - 1) // (out_w - 1))
+    height_scale = T((in_h - 1) // (out_h - 1))
+
+    @inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1
+
+    @inbounds Threads.@threads for c in 0:channels-1
+        for oh in 0:out_h-1
+            ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h)
+            for ow in 0:out_w-1
+                iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
+                output_offset = c * output_slice_size + oh * out_w + ow + 1
+                Δ_value = Δ[output_offset]
+                dx[idx(c, ih0, iw0)] += h0lambda * w0lambda * Δ_value # i00
+                dx[idx(c, ih0, iw1)] += h0lambda * w1lambda * Δ_value # i01
+                dx[idx(c, ih1, iw0)] += h1lambda * w0lambda * Δ_value # i10
+                dx[idx(c, ih1, iw1)] += h1lambda * w1lambda * Δ_value # i11
+            end
+        end
     end
-    # TODO there must be a more convenient way to send to gpu
-    kernel = convert(typeof(Δ), reshape(kernel, length(kernel), 1, 1, 1))
-    kernel = dropdims(kernel, dims=(2,3,4))
-    return kernel
+    return dx
 end

-function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, k)
-    Ω = upsample_bilinear(x, k)
+function ChainRulesCore.rrule(::typeof(upsample_bilinear), x; size)
+    Ω = upsample_bilinear(x; size=size)
     function upsample_bilinear_pullback(Δ)
-        (NO_FIELDS, ∇upsample_bilinear(Δ, k), DoesNotExist())
+        (NO_FIELDS, ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2))))
     end
     return Ω, upsample_bilinear_pullback
 end

-
 """
     pixel_shuffle(x, r)
-    
+
 Pixel shuffling operation. `r` is the upscale factor for shuffling.
 The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N]
-Used extensively in super-resolution networks to upsample 
+Used extensively in super-resolution networks to upsample
 towards high resolution features.
Reference : https://arxiv.org/pdf/1609.05158.pdf @@ -301,7 +232,7 @@ function pixel_shuffle(x::AbstractArray, r::Integer) @assert ndims(x) > 2 d = ndims(x) - 2 sizein = size(x)[1:d] - cin, n = size(x, d+1), size(x, d+2) + cin, n = size(x, d+1), size(x, d+2) @assert cin % r^d == 0 cout = cin ÷ r^d # x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866 diff --git a/test/pooling.jl b/test/pooling.jl index b7fe6c44d..d1d26c620 100644 --- a/test/pooling.jl +++ b/test/pooling.jl @@ -319,7 +319,7 @@ end # test "true" strided case, see https://github.com/FluxML/NNlib.jl/issues/205 -# obtained with +# obtained with # using FiniteDifferences maxpool_answer_nature = Dict( "rank1" => Dict( @@ -327,62 +327,62 @@ maxpool_answer_nature = Dict( "k2s1p0" => (size = (2,), stride = 1, pad = 0, - - x = reshape([ + + x = reshape([ 0.0299635, 0.233456, 0.596161, 0.161514, 0.0094027 ], 5, 1, 1), # width, channel, batch_size - + dx_maxpool = reshape([ 0.0, 1.0, 2.0, 1.0, 0.0 ], 5, 1, 1), - + dx_meanpool = reshape([ 0.5, 1.0, 1.0, 1.0, 0.5 ], 5, 1, 1),), "k2s1p1" => (size = (2,), stride = 1, pad = 1, - - x = reshape([ + + x = reshape([ 0.0299635, 0.233456, 0.596161, 0.161514, 0.0094027 ], 5, 1, 1), - + dx_maxpool = reshape([ 1.0, 1.0, 2.0, 1.0, 1.0 ], 5, 1, 1), - + dx_meanpool = reshape([ 1.0, 1.0, 1.0, 1.0, 1.0 ], 5, 1, 1),), "k3s1p1" => (size = (3,), stride = 1, pad = 1, - - x = reshape([ + + x = reshape([ 0.0299635, 0.233456, 0.596161, 0.161514, 0.0094027 ], 5, 1, 1), - + dx_maxpool = reshape([ 0.0, 1.0, 3.0, 1.0, 0.0 ], 5, 1, 1), - + dx_meanpool = reshape([ 0.6666666666, 1.0, 1.0, 1.0, 0.6666666666 ], 5, 1, 1),), "k3s2p1" => (size = (3,), stride = 2, pad = 1, - - x = reshape([ + + x = reshape([ 0.0299635, 0.233456, 0.596161, 0.161514, 0.0094027 ], 5, 1, 1), - + dx_maxpool = reshape([ 0.0, 1.0, 1.0, 1.0, 0.0 ], 5, 1, 1), - + dx_meanpool = reshape([ - 0.333333333, + 0.333333333, 0.666666666, 0.333333333, 0.666666666, @@ -395,7 +395,7 @@ maxpool_answer_nature = Dict( stride = 1, pad = 0, - x = reshape([ + x = reshape([ 0.0299635 0.233456 0.596161 0.161514 0.0094027 0.389984 0.235158 0.579525 0.301893 0.561358 0.0830242 0.483759 0.914904 0.253871 0.820061 @@ -408,9 +408,9 @@ maxpool_answer_nature = Dict( 1.0 0.0 0.0 0.0 1.0 0.0 1.0 4.0 0.0 2.0 0.0 1.0 0.0 2.0 0.0 - 0.0 2.0 0.0 0.0 0.0 + 0.0 2.0 0.0 0.0 0.0 ], 5, 5, 1, 1), - + dx_meanpool = reshape([ 0.25 0.5 0.5 0.5 0.25 0.5 1.0 1.0 1.0 0.5 @@ -421,8 +421,8 @@ maxpool_answer_nature = Dict( "k2s1p1" => (size = (2, 2), stride = 1, pad = 1, - - x = reshape([ + + x = reshape([ 0.0299635 0.233456 0.596161 0.161514 0.0094027 0.389984 0.235158 0.579525 0.301893 0.561358 0.0830242 0.483759 0.914904 0.253871 0.820061 @@ -437,7 +437,7 @@ maxpool_answer_nature = Dict( 1.0 1.0 0.0 2.0 0.0 2.0 4.0 1.0 0.0 3.0 ], 5, 5, 1, 1), - + dx_meanpool = reshape([ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 @@ -449,7 +449,7 @@ maxpool_answer_nature = Dict( stride = 1, pad = 1, - x = reshape([ + x = reshape([ 0.0299635 0.233456 0.596161 0.161514 0.0094027 0.389984 0.235158 0.579525 0.301893 0.561358 0.0830242 0.483759 0.914904 0.253871 0.820061 @@ -464,7 +464,7 @@ maxpool_answer_nature = Dict( 0.0 1.0 0.0 3.0 0.0 0.0 3.0 0.0 0.0 0.0 ], 5, 5, 1, 1), - + dx_meanpool = reshape([ 0.444444 0.666667 0.666667 0.666667 0.444444 0.666667 1.0 1.0 1.0 0.666667 @@ -476,7 +476,7 @@ maxpool_answer_nature = Dict( stride = 2, pad = 1, - x = reshape([ + x = reshape([ 0.0299635 0.233456 0.596161 0.161514 0.0094027 0.389984 0.235158 0.579525 0.301893 0.561358 
0.0830242 0.483759 0.914904 0.253871 0.820061 @@ -491,7 +491,7 @@ maxpool_answer_nature = Dict( 0.0 1.0 0.0 2.0 0.0 0.0 1.0 0.0 0.0 0.0 ], 5, 5, 1, 1), - + dx_meanpool = reshape([ 0.111111 0.222222 0.111111 0.222222 0.111111 0.222222 0.444444 0.222222 0.444444 0.222222 @@ -505,7 +505,7 @@ maxpool_answer_nature = Dict( "k2s1p0" => (size = (2, 2, 2), stride = 1, pad = 0, - + x = reshape(cat([ 0.82584 0.416818 0.92668 0.471931 0.798798 0.131608 0.344556 0.79681 @@ -524,7 +524,7 @@ maxpool_answer_nature = Dict( 0.640221 0.28811 0.129229 0.97571 0.953795 0.1316 0.94538 0.705337 ],dims=3), 4,4,3,1,1), - + dx_maxpool = reshape(cat([ 1.0 0.0 2.0 0.0 1.0 0.0 0.0 0.0 @@ -543,7 +543,7 @@ maxpool_answer_nature = Dict( 0.0 0.0 0.0 2.0 1.0 0.0 1.0 0.0 ],dims=3), 4,4,3,1,1), - + dx_meanpool = reshape(cat([ 0.125 0.25 0.25 0.125 0.25 0.5 0.5 0.25 @@ -565,7 +565,7 @@ maxpool_answer_nature = Dict( "k2s1p1" => (size = (2, 2, 2), stride = 1, pad = 1, - + x = reshape(cat([ 0.82584 0.416818 0.92668 0.471931 0.798798 0.131608 0.344556 0.79681 @@ -584,7 +584,7 @@ maxpool_answer_nature = Dict( 0.640221 0.28811 0.129229 0.97571 0.953795 0.1316 0.94538 0.705337 ],dims=3), 4,4,3,1,1), - + dx_maxpool = reshape(cat([ 8.0 0.0 8.0 2.0 4.0 0.0 1.0 4.0 @@ -603,7 +603,7 @@ maxpool_answer_nature = Dict( 3.0 0.0 0.0 8.0 8.0 0.0 6.0 1.0 ],dims=3), 4,4,3,1,1), - + dx_meanpool = reshape(cat([ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 @@ -625,7 +625,7 @@ maxpool_answer_nature = Dict( "k3s1p1" => (size = (3, 3, 2), stride = 1, pad = 1, - + x = reshape(cat([ 0.82584 0.416818 0.92668 0.471931 0.798798 0.131608 0.344556 0.79681 @@ -644,7 +644,7 @@ maxpool_answer_nature = Dict( 0.640221 0.28811 0.129229 0.97571 0.953795 0.1316 0.94538 0.705337 ],dims=3), 4,4,3,1,1), - + dx_maxpool = reshape(cat([ 4.0 0.0 12.0 0.0 3.0 0.0 0.0 2.0 @@ -663,7 +663,7 @@ maxpool_answer_nature = Dict( 0.0 0.0 0.0 12.0 8.0 0.0 0.0 0.0 ],dims=3), 4,4,3,1,1), - + dx_meanpool = reshape(cat([ 0.444444 0.666667 0.666667 0.444444 0.666667 1.0 1.0 0.666667 @@ -685,7 +685,7 @@ maxpool_answer_nature = Dict( "k3s2p1" => (size = (3, 3, 2), stride = 2, pad = 1, - + x = reshape(cat([ 0.82584 0.416818 0.92668 0.471931 0.798798 0.131608 0.344556 0.79681 @@ -704,7 +704,7 @@ maxpool_answer_nature = Dict( 0.640221 0.28811 0.129229 0.97571 0.953795 0.1316 0.94538 0.705337 ],dims=3), 4,4,3,1,1), - + dx_maxpool = reshape(cat([ 1.0 0.0 1.0 0.0 1.0 0.0 0.0 1.0 @@ -723,7 +723,7 @@ maxpool_answer_nature = Dict( 0.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 ],dims=3), 4,4,3,1,1), - + dx_meanpool = reshape(cat([ 0.0555556 0.111111 0.0555556 0.0555556 0.111111 0.222222 0.111111 0.111111 @@ -750,7 +750,7 @@ maxpool_answer_nature = Dict( # issue #205 function check(config, T) # CHECK DEFAULT - pdims = PoolDims(config.x, config.size; stride=config.stride, padding=config.pad) + pdims = PoolDims(config.x, config.size; stride=config.stride, padding=config.pad) x = T.(config.x) y_maxpool = NNlib.maxpool(x, pdims) y_meanpool = NNlib.meanpool(x, pdims) diff --git a/test/upsample.jl b/test/upsample.jl index 08886b4d0..7dec97789 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -13,32 +13,65 @@ end @testset "upsample_bilinear 2d" begin - x = reshape(Float32[1. 2.; 3. 
4.], (2,2,1,1)) - y_true = [1//1 5//4 7//4 2//1; - 1//1 5//4 7//4 2//1; - 5//3 23//12 29//12 8//3; - 7//3 31//12 37//12 10//3; - 3//1 13//4 15//4 4//1; - 3//1 13//4 15//4 4//1][:,:,:,:] - + x = Float32[1 2; 3 4][:,:,:,:] + x = cat(x,x; dims=3) + x = cat(x,x; dims=4) + + # this output matches the one of pytorch v1.5.0 + # nn.UpsamplingBilinear2d(scale_factor=(3,2), align_corners=True) + # for above x + y_true = Float32[ 1//1 4//3 5//3 2//1; + 7//5 26//15 31//15 12//5; + 9//5 32//15 37//15 14//5; + 11//5 38//15 43//15 16//5; + 13//5 44//15 49//15 18//5; + 3//1 10//3 11//3 4//1][:,:,:,:] + y_true = cat(y_true,y_true; dims=3) + y_true = cat(y_true,y_true; dims=4) + y = upsample_bilinear(x, (3, 2)) @test size(y) == size(y_true) @test eltype(y) == Float32 @test y ≈ y_true - gradtest(x->upsample_bilinear(x, (3, 2)), x, atol=1e-4) - - if CUDA.has_cuda() - y = upsample_bilinear(x |> cu, (3, 2)) - @test y isa CuArray - @test Array(y) ≈ y_true - g_gpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2)))) - , x |> cu)[1] - @test g_gpu isa CuArray - g_cpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2)))) - , x)[1] - @test Array(g_cpu) ≈ g_cpu atol=1e-4 - end + gradtest(x->upsample_bilinear(x, (3, 2)), x, atol=1e-3) # works to higher precision for Float64 + + # additional grad check, also compliant with pytorch + o = ones(Float32,6,4,2,1) + grad_true = 6*ones(Float32,2,2,2,1) + @test ∇upsample_bilinear(o; size = (2,2)) ≈ grad_true + + y_true_2 = Rational{Int}[1//1 5//4 6//4 7//4 2//1; + 3//2 7//4 8//4 9//4 5//2; + 4//2 9//4 10//4 11//4 6//2; + 5//2 11//4 12//4 13//4 7//2; + 3//1 13//4 14//4 15//4 4//1][:,:,:,:] + + # check for real-valued single-number argument and type stability for rationals + upsample_bilinear(x, 2.5) == y_true_2 + + # check Integer support for forward pass + # grads are always assumed to be floats, so no extension there + x = UInt8[1 3; 3 5][:,:,:,:] + y_true_int = UInt8[1 2 3; 2 3 4; 3 4 5][:,:,:,:] + y = upsample_bilinear(x, 1.5) + + @test eltype(y) == UInt8 + @test y == y_true_int + + # this test can be performed again, as soon as the corresponding CUDA functionality is merged + + # if CUDA.has_cuda() + # y = upsample_bilinear(x |> cu, (3, 2)) + # @test y isa CuArray + # @test Array(y) ≈ y_true + # g_gpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2)))) + # , x |> cu)[1] + # @test g_gpu isa CuArray + # g_cpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2)))) + # , x)[1] + # @test Array(g_cpu) ≈ g_cpu atol=1e-4 + # end end @testset "pixel_shuffle" begin @@ -60,7 +93,7 @@ end 5 13 7 15 2 10 4 12 6 14 8 16][:,:,:,:] - + y = pixel_shuffle(x, 2) @test size(y) == size(y_true) @test y_true == y @@ -84,7 +117,7 @@ end x = reshape(1:4*3*27*2, (4,3,27,2)) y = pixel_shuffle(x, 3) @test size(y) == (12, 9, 3, 2) - # batch dimension is preserved + # batch dimension is preserved x1 = x[:,:,:,[1]] x2 = x[:,:,:,[2]] y1 = pixel_shuffle(x1, 3) @@ -97,7 +130,7 @@ end c = rand(1:5) insize = rand(1:5, d) x = rand(insize..., r^d*c, n) - + y = pixel_shuffle(x, r) @test size(y) == ((r .* insize)..., c, n)
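
The patch replaces the old conv-based bilinear routines with a direct PyTorch-style kernel and widens the public API: `upsample_bilinear` now accepts real scale factors, a single scalar scale, or an explicit output `size` keyword. A minimal usage sketch of that API, assuming the patch is applied (the `upsample_bilinear` name and WHCN layout come straight from the diff; everything else is illustrative):

```julia
using NNlib: upsample_bilinear

x = reshape(Float32[1 2; 3 4], 2, 2, 1, 1)   # WHCN layout, as assumed by the kernels above

y1 = upsample_bilinear(x, (3, 2))            # scale factor per spatial dimension (reals allowed)
y2 = upsample_bilinear(x; size = (6, 4))     # or request the output (W, H) directly

@assert size(y1) == (6, 4, 1, 1)
@assert y1 == y2   # the scale method just computes the target size and calls the size method
```

Both calls align corners, so the output matches PyTorch's `nn.UpsamplingBilinear2d(..., align_corners=True)`, as noted in the test file.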
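The numeric behaviour is easiest to see on one row of the test data. `compute_source_index_and_lambda` maps each output index to two source indices and their interpolation weights using the align-corners ratio `(in-1)/(out-1)`; applied by hand to the row `[1, 2]` upsampled to length 4, it reproduces the `1, 4/3, 5/3, 2` entries of `y_true` above. The sketch below restates the helper standalone (hypothetical name `source_index_and_lambda`, 0-based indices as in the kernel) purely as a sanity check:

```julia
# Standalone restatement of the patch's compute_source_index_and_lambda (align_corners = true).
function source_index_and_lambda(ratio, output_index, input_size)
    real_index = ratio * output_index              # 0-based position in the input grid
    i0 = floor(Int, real_index)
    i1 = i0 + (i0 < input_size - 1 ? 1 : 0)        # clamp the upper neighbour at the border
    lambda1 = real_index - i0
    lambda0 = 1 - lambda1
    return i0, i1, lambda0, lambda1
end

row = [1.0, 2.0]
ratio = (length(row) - 1) / (4 - 1)                # (in - 1) / (out - 1) = 1/3
out = zeros(4)
for o in 0:3
    i0, i1, l0, l1 = source_index_and_lambda(ratio, o, length(row))
    out[o + 1] = l0 * row[i0 + 1] + l1 * row[i1 + 1]   # shift to 1-based for array access
end

@assert out ≈ [1.0, 4/3, 5/3, 2.0]   # matches the first row of y_true in the 2d test above
```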
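On the gradient side, the new `rrule` covers the `size`-keyword form and its pullback re-derives the input's spatial size from `x`. A small sketch of calling it directly, assuming the ChainRulesCore version this patch targets (the one that still uses `NO_FIELDS`):

```julia
using ChainRulesCore
using NNlib: upsample_bilinear

x = rand(Float32, 2, 2, 1, 1)
Ω, pullback = ChainRulesCore.rrule(upsample_bilinear, x; size = (4, 4))

Δ = ones(Float32, size(Ω))
_, dx = pullback(Δ)                  # returns (NO_FIELDS, dx) per the rrule in the diff

@assert size(dx) == size(x)
# Each output pixel spreads a total weight of 1 over its source pixels,
# so the adjoint preserves the gradient mass.
@assert sum(dx) ≈ sum(Δ)
```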