From 6520ed3418e564aac55d23a083bf413ca264ae89 Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Mon, 11 Jan 2021 16:11:42 +0100 Subject: [PATCH 01/21] improve bilinear upsampling --- src/upsample.jl | 370 ++++++++++++++++++++++------------------------- test/upsample.jl | 52 ++++--- 2 files changed, 204 insertions(+), 218 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index 860e8441b..43bea9f65 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -1,234 +1,214 @@ export upsample_bilinear, ∇upsample_bilinear, pixel_shuffle +# utility function +@inline function compute_source_index_and_lambda( + ratio, # 0 < ratio < 1 + output_index, + input_size, + output_size +) + real_input_index = ratio*output_index + input_index0 = floor(Int, real_input_index) # typecast to int was here in C++ + offset = (input_index0 < input_size - 1) ? 1 : 0 + input_index1 = input_index0 + offset + lambda1 = real_input_index - input_index0 + lambda0 = 1 - lambda1 + return input_index0, input_index1, lambda0, lambda1 +end + """ - upsample_bilinear(x::AbstractArray{<:Number,4}, k::NTuple{2,Int}) + upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) -Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `k`, -using bilinear interpolation. +Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `scale`, +using bilinear interpolation. -The size of the output is equal to -`(k[1]*S1, k[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`. +The size of the output is equal to +`(scale[1]*S1, scale[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`. -The interpolation grid is identical to the one used by `imresize` from `Images.jl`. +The interpolation grid is identical to the one used by open-cv. Currently only 2d upsampling is supported. """ -function upsample_bilinear(x::AbstractArray{T,4}, k::NTuple{2,Int}) where T - # This function is gpu friendly - - imgsize = size(x) - newsize = get_newsize(imgsize, k) - - # Get linear interpolation lower- and upper index, and weights - ilow1, ihigh1, wdiff1 = get_inds_and_ws(x, imgsize[1], newsize[1], 1) - ilow2, ihigh2, wdiff2 = get_inds_and_ws(x, imgsize[2], newsize[2], 2) - - # Adjust the upper interpolation indices of the second dimension - ihigh2_r = adjoint_of_idx(ilow2)[ihigh2] - - @inbounds y = @view(x[ilow1,ilow2,:,:]) .* (1 .- wdiff1) .+ @view(x[ihigh1,ilow2,:,:]) .* wdiff1 - @inbounds y .= y .* (1 .- wdiff2) .+ y[:,ihigh2_r,:,:] .* wdiff2 - # @inbounds y = y .* (1 .- wdiff2) .+ @view(y[:,ihigh2_r,:,:]) .* wdiff2 # equivalent to line above - return y -end - -function get_inds_and_ws(x::T, n::Int, m::Int, dim::Int) where T <: AbstractArray - # Creates interpolation grid for resampling. - # Creates the same grid as used in Image.jl `imresize`. - step = n // m - offset = (n + 1)//2 - step//2 - step * (m//2 - 1) - xq = clamp.(range(offset, step=step, length=m), 1, n) - - # Creates interpolation lower and upper indices, and broadcastable weights - ilow = floor.(Int, xq) - ihigh = ceil.(Int, xq) - sizew = ntuple(i-> i == dim ? 
length(xq) : 1, ndims(x)) - wdiff = convert(T, reshape(xq .- ilow, sizew)) # wdiff possibly lives on gpu - return ilow, ihigh, wdiff +function upsample_bilinear end + +# this is the user-facing part +function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T + w,h,c,n = size(x) + if outsize===nothing + out_w = floor(Int, scale[1]*w) + out_h = floor(Int, scale[2]*h) + else + out_w, out_h = outsize + end + y = Array{T,4}(undef, out_w, out_h, c, n) + return upsample_bilinear_whcn!(y, x) end -""" - adjoint_of_idx(idx::Vector{<:Integer}) - -# Arguments -- `idx`: a vector of indices from which you want the adjoint. - -# Outputs --`idx_adjoint`: index that inverses the operation `x[idx]`. - -# Explanation -Determines the adjoint of the vector of indices `idx`, based on the following assumptions: -* `idx[1] == 1` -* `all(d in [0,1] for d in diff(idx))` -The adjoint of `idx` can be seen as an inverse operation such that: - -```julia -x = [1, 2, 3, 4, 5] -idx = [1, 2, 2, 3, 4, 4, 5] -idx_adjoint = adjoint_of_idx(idx) -@assert x[idx][idx_adjoint] == x -``` -The above holds as long as `idx` contains every index in `x`. -""" -function adjoint_of_idx(idx::Vector{Int}) - d = trues(length(idx)) - d[2:end] .= diff(idx) - idx_adjoint = findall(d) - return idx_adjoint +# this is the core function which works on arrays of arbitrary size +# the implementation is a translation of https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp +# which implements open-cv style linear interpolation / upsampling +# for simplicity, corners are aligned and all logic for other behaviour has been stripped +# - whcn because there is also a cwhn implementation +# - the function is parallelized using @threads +# - RGB types could be supported via reinterpreting +# - integer types need to be converted to Float and back +# - rationals work, but are slow +function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T + if size(input) == size(output) + return input + end + size(input)[3:4] == size(output)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))") + in_w, in_h, channels, batches = size(input) + # treat batch and channel dimension as one for better parallelization granularity + channels *= batches + out_w, out_h, _, _ = size(output) + output_slice_size = out_h * out_w + + # T() and // so that we can handle rationals (super slow) + width_scale = T((in_w - 1) // (out_w - 1)) + height_scale = T((in_h - 1) // (out_h - 1)) + + @inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1 + + @inbounds Threads.@threads for c in 0:channels-1 + for oh in 0:out_h-1 + ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h) + for ow in 0:out_w-1 + iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w) + output_offset = c * output_slice_size + oh * out_w + ow + 1 + output[output_offset] = + (h0lambda * w0lambda * input[idx(c, ih0, iw0)] + # h0 * w0 * i00 + h0lambda * w1lambda * input[idx(c, ih0, iw1)] + # h0 * w1 * i01 + h1lambda * w0lambda * input[idx(c, ih1, iw0)] + # h1 * w0 * i10 + h1lambda * w1lambda * input[idx(c, ih1, iw1)]) # h1 * w1 * i11 + end + end + end + return output end -function get_newsize(sz, k) - return ntuple(i -> i <= length(k) ? 
sz[i]*k[i] : sz[i], length(sz)) +# this is just for convenience, to support matrices + +# function upsample_bilinear_whcn!(output::AbstractMatrix, input::AbstractMatrix) +# y = reshape(output, (size(output)..., 1, 1)) +# x = reshape(input, (size(input)..., 1, 1)) +# return dropdims(upsample_bilinear_whcn!(y, x); dims=(3,4)) +# end + +# function upsample_bilinear(x::AbstractArray{T,2}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T +# w,h = size(x) +# if outsize===nothing +# out_w = floor(Int, scale[1]*w) +# out_h = floor(Int, scale[2]*h) +# else +# out_w, out_h = outsize +# end +# y = Array{T,2}(undef, out_w, out_h) +# return upsample_bilinear_whcn!(y,x) +# end + + +# here is a cwhn implementation for later use +# much faster than whcn on single core, multi-threaded they are about the same, depends on n and c +# innermost channel loop should vectorize automatically +function upsample_bilinear_cwhn!(output::AbstractArray{T,4},input::AbstractArray{T,4}) where T + if size(input) == size(output) + return input + end + size(input)[[1,4]] == size(output)[[1,4]] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))") + channels, in_w, in_h, batches = size(input) + out_w, out_h = size(output)[[2,3]] + + width_scale = (in_w - 1) / (out_w - 1) # area_pixel_compute_scale(in_w, out_w, align_corners, scales[1]); + height_scale = (in_h - 1) / (out_h - 1) # area_pixel_compute_scale(in_h, out_h, align_corners, scales[2]); + + # might make more sense to parallelize an inner loop in case batches=1 + @inbounds Threads.@threads for n in 1:batches + for oh in 0:out_h-1 + ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h) + for ow in 0:out_w-1 + iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w) + for c in 1:channels + output[c, ow+1, oh+1, n] = + (h0lambda * w0lambda * input[c, iw0+1, ih0+1, n] + # h0 * w0 * i00 + h0lambda * w1lambda * input[c, iw1+1, ih0+1, n] + # h0 * w1 * i01 + h1lambda * w0lambda * input[c, iw0+1, ih1+1, n] + # h1 * w0 * i10 + h1lambda * w1lambda * input[c, iw1+1, ih1+1, n]) # h1 * w1 * i11 + end + end + end + end + return output end - """ - ∇upsample_bilinear(Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int}) - + ∇upsample_bilinear(Δ::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T + # Arguments -- `Δ`: array that has been upsampled using the upsample factors in `k` +- `Δ`: incoming gradient array that has been upsampled using the upsample factors in `scale` # Outputs - `dx`: downsampled version of `Δ` - -# Explanation - -Custom adjoint for [`upsample_bilinear`](@ref). -The adjoint of upsampling is a downsampling operation, which -in this implementation is performed using `NNlib.conv` in combination with a downsampling kernel based on the -upsampling factors. Because of the zero-padding during convolution, the values at the boundary are polluted by edge-effects, -which have been corrected for manually. 
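Both the conv-based adjoint described above (removed by this patch) and the direct scatter kernel that replaces it compute the same mathematical object: bilinear upsampling is a linear map, so its pullback must be multiplication by that map's transpose. A minimal property check, as a sketch (not part of the diff; it assumes NNlib with this patch applied, where `∇upsample_bilinear` still takes the scale tuple positionally):

```julia
using NNlib, LinearAlgebra

# Build the dense 36×9 matrix of the 3×3 -> 6×6 upsampling operator column by
# column, then compare the pullback against multiplication by the transpose.
e(i) = reshape(Float64.((1:9) .== i), 3, 3, 1, 1)   # i-th basis "image"
M = reduce(hcat, [vec(upsample_bilinear(e(i), (2, 2))) for i in 1:9])
Δ = rand(Float64, 6, 6, 1, 1)
@assert vec(∇upsample_bilinear(Δ, (2, 2))) ≈ M' * vec(Δ)  # adjoint ≡ transpose
```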
""" -function ∇upsample_bilinear(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) - # This function is gpu friendly - - # Be more efficient on some corner cases - if size(Δ, 1) == k[1] - Δ = sum(Δ, dims=1) - k = (1, k[2]) - end - if size(Δ, 2) == k[2] - Δ = sum(Δ, dims=2) - k = (k[1], 1) - end - if (size(Δ, 1) == 1) && (size(Δ, 2) == 1) - dx = Δ - return dx - end - - n_chan, n_batch = size(Δ, 3), size(Δ, 4) - - kern1 = get_downsamplekernel(Δ, k[1]) - kern2 = get_downsamplekernel(Δ, k[2]) - kern = kern1 * kern2' - - pad = (floor(Int, k[1]//2), floor(Int, k[2]//2)) - stride = k - - weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan)) - weight .= 0 - for i in 1:n_chan - weight[:,:,i,i] .= kern - end - # weight = cat(fill(kern, n_chan)..., dims=(3,4)) # slow - dx = conv(Δ, weight, pad=pad, stride=stride) - - # Still have to fix edge effects due to zero-padding of convolution, - # TODO: Could be circumvented by having padding that just extrapolates the value at the first/last index - # nextras = tuple((Int.(floor(factor//2)) for factor in k)...) - nextras = (floor(Int, k[1]//2), floor(Int, k[2]//2)) - - # First dimension edge-effect correction - if nextras[1] > 0 - kern1 = kern[1:nextras[1],:] - pad1 = (0, pad[2]) - stride1 = (1, stride[2]) - weight1 = similar(Δ, eltype(Δ), (size(kern1)..., n_chan, n_chan)) - weight1 .= 0 - for i in 1:n_chan - weight1[:,:,i,i] .= kern1 - end - # weight1 = cat(fill(kern1, n_chan)..., dims=(3,4)) # slow - dx[[1],:,:,:] .+= conv(Δ[1:nextras[1],:,:,:], weight1, pad=pad1, stride=stride1) - weight1 .= weight1[end:-1:1,:,:,:] - dx[[end],:,:,:] .+= conv(Δ[end-nextras[1]+1:end,:,:,:], weight1, pad=pad1, stride=stride1) - - ## Conv with views is not dispatched to CUDA.conv - # dx[[1],:,:,:] .+= conv(@view(Δ[1:nextras[1],:,:,:]), weight1, pad=pad1, stride=stride1) - # weight1 .= @view(weight1[end:-1:1,:,:,:]) - # dx[[end],:,:,:] .+= conv(@view(Δ[end-nextras[1]+1:end,:,:,:]), weight1, pad=pad1, stride=stride1) - end +function ∇upsample_bilinear end - # Second dimension edge-effect correction - if nextras[2] > 0 - kern2 = kern[:,1:nextras[2]] - pad2 = (pad[1], 0) - stride2 = (stride[1], 1) - weight2 = similar(Δ, eltype(Δ), (size(kern2)..., n_chan, n_chan)) - weight2 .= 0 - for i in 1:n_chan - weight2[:,:,i,i] .= kern2 - end - # weight2 = cat(fill(kern2, n_chan)..., dims=(3,4)) # slow - - yy = conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2) - dx[:,[1],:,:] .+= conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2) - weight2 .= weight2[:,end:-1:1,:,:] - dx[:,[end],:,:] .+= conv(Δ[:,end-nextras[2]+1:end,:,:], weight2, pad=pad2, stride=stride2) - - ## Conv with views is not dispatched to CUDA.conv - # yy = conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2) - # dx[:,[1],:,:] .+= conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2) - # weight2 .= @view(weight2[:,end:-1:1,:,:]) - # dx[:,[end],:,:] .+= conv(@view(Δ[:,end-nextras[2]+1:end,:,:]), weight2, pad=pad2, stride=stride2) - end - - ## Finally fix four corners if needed - n1, n2 = nextras - if (n1 > 0) & (n2 > 0) - dx[1,1,:,:] .+= sum(kern[1:n1,1:n2] .* @view(Δ[1:n1,1:n2,:,:]), dims=(1,2))[1,1,:,:] - dx[1,end,:,:] .+= sum(kern[1:n1,end-n2+1:end] .* @view(Δ[1:n1,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:] - dx[end,end,:,:] .+= sum(kern[end-n1+1:end,end-n2+1:end] .* @view(Δ[end-n1+1:end,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:] - dx[end,1,:,:] .+= sum(kern[end-n1+1:end,1:n2] .* @view(Δ[end-n1+1:end,1:n2,:,:]), dims=(1,2))[1,1,:,:] +function 
∇upsample_bilinear(Δ::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T + w,h,c,n = size(Δ) + if outsize===nothing + out_w = ceil(Int, w/scale[1]) + out_h = ceil(Int, h/scale[2]) + else + out_w, out_h = outsize end - - return dx + dx = zeros(T, out_w, out_h, c, n) + return ∇upsample_bilinear_whcn!(Δ, dx) end -# `n` upsample factor for which a downsample kernel will be determined. -# Δ is given in case of necessity of gpu conversion -function get_downsamplekernel(Δ, n::Int) - step = 1//n - if n % 2 == 0 - start = step//2 - upward = collect(start:step:1//1) - kernel = [upward; reverse(upward)] - else - start = step - upward = collect(start:step:1//1) - kernel = [upward; reverse(upward[1:end-1])] +function ∇upsample_bilinear_whcn!(Δ::AbstractArray{T,4}, grad_input::AbstractArray{T,4}) where T + size(grad_input)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))") + in_w, in_h, channels, batches = size(grad_input) + + # treat batch and channel dimension as one for better parallelization granularity + channels *= batches + out_w, out_h, _, _ = size(Δ) + output_slice_size = out_h * out_w + + width_scale = T((in_w - 1) // (out_w - 1)) # area_pixel_compute_scale(in_w, out_w, align_corners, scales[1]); + height_scale = T((in_h - 1) // (out_h - 1)) # area_pixel_compute_scale(in_h, out_h, align_corners, scales[2]); + + @inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1 + + @inbounds Threads.@threads for c in 0:channels-1 + for oh in 0:out_h-1 + ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h) + for ow in 0:out_w-1 + iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w) + output_offset = c * output_slice_size + oh * out_w + ow + 1 + Δ_value = Δ[output_offset] + grad_input[idx(c, ih0, iw0)] += h0lambda * w0lambda * Δ_value # i00 + grad_input[idx(c, ih0, iw1)] += h0lambda * w1lambda * Δ_value # i01 + grad_input[idx(c, ih1, iw0)] += h1lambda * w0lambda * Δ_value # i10 + grad_input[idx(c, ih1, iw1)] += h1lambda * w1lambda * Δ_value # i11 + end + end end - # TODO there must be a more convenient way to send to gpu - kernel = convert(typeof(Δ), reshape(kernel, length(kernel), 1, 1, 1)) - kernel = dropdims(kernel, dims=(2,3,4)) - return kernel + return grad_input end -function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, k) - Ω = upsample_bilinear(x, k) +function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, scale; outsize=nothing) + Ω = upsample_bilinear(x, scale; outsize=outsize) function upsample_bilinear_pullback(Δ) - (NO_FIELDS, ∇upsample_bilinear(Δ, k), DoesNotExist()) + (NO_FIELDS, ∇upsample_bilinear(Δ, scale; outsize=(size(x,1),size(x,2))), DoesNotExist()) end return Ω, upsample_bilinear_pullback end - """ pixel_shuffle(x, r) - + Pixel shuffling operation. `r` is the upscale factor for shuffling. The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N] -Used extensively in super-resolution networks to upsample +Used extensively in super-resolution networks to upsample towards high resolution features. 
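For orientation, a minimal shape sketch (illustrative only; the values are arbitrary):

```julia
using NNlib

x = rand(Float32, 2, 3, 4, 5)    # W=2, H=3, C=r²·1=4 for r=2, N=5
y = pixel_shuffle(x, 2)
@assert size(y) == (4, 6, 1, 5)  # (rW, rH, C÷r², N)
```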
Reference : https://arxiv.org/pdf/1609.05158.pdf @@ -237,7 +217,7 @@ function pixel_shuffle(x::AbstractArray, r::Integer) @assert ndims(x) > 2 d = ndims(x) - 2 sizein = size(x)[1:d] - cin, n = size(x, d+1), size(x, d+2) + cin, n = size(x, d+1), size(x, d+2) @assert cin % r^d == 0 cout = cin ÷ r^d # x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866 diff --git a/test/upsample.jl b/test/upsample.jl index 6d07a1344..3c4c7e6f0 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -1,12 +1,17 @@ @testset "upsample_bilinear 2d" begin - x = reshape(Float32[1. 2.; 3. 4.], (2,2,1,1)) - y_true = [1//1 5//4 7//4 2//1; - 1//1 5//4 7//4 2//1; - 5//3 23//12 29//12 8//3; - 7//3 31//12 37//12 10//3; - 3//1 13//4 15//4 4//1; - 3//1 13//4 15//4 4//1][:,:,:,:] - + x = Float32[1 2; 3 4][:,:,:,:] + x = cat(x,x; dims=3) + x = cat(x,x; dims=4) + + y_true = [ 1//1 4//3 5//3 2//1; + 7//5 26//15 31//15 12//5; + 9//5 32//15 37//15 14//5; + 11//5 38//15 43//15 16//5; + 13//5 44//15 49//15 18//5; + 3//1 10//3 11//3 4//1] + y_true = cat(y_true,y_true; dims=3) + y_true = cat(y_true,y_true; dims=4) + y = upsample_bilinear(x, (3, 2)) @test size(y) == size(y_true) @test eltype(y) == Float32 @@ -14,17 +19,19 @@ gradtest(x->upsample_bilinear(x, (3, 2)), x, atol=1e-4) - if CUDA.has_cuda() - y = upsample_bilinear(x |> cu, (3, 2)) - @test y isa CuArray - @test Array(y) ≈ y_true - g_gpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2)))) - , x |> cu)[1] - @test g_gpu isa CuArray - g_cpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2)))) - , x)[1] - @test Array(g_cpu) ≈ g_cpu atol=1e-4 - end + # this test can be performed again, as soon as the corresponding CUDA functionality is merged + + # if CUDA.has_cuda() + # y = upsample_bilinear(x |> cu, (3, 2)) + # @test y isa CuArray + # @test Array(y) ≈ y_true + # g_gpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2)))) + # , x |> cu)[1] + # @test g_gpu isa CuArray + # g_cpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2)))) + # , x)[1] + # @test Array(g_cpu) ≈ g_cpu atol=1e-4 + # end end @testset "pixel_shuffle" begin @@ -46,7 +53,7 @@ end 5 13 7 15 2 10 4 12 6 14 8 16][:,:,:,:] - + y = pixel_shuffle(x, 2) @test size(y) == size(y_true) @test y_true == y @@ -70,7 +77,7 @@ end x = reshape(1:4*3*27*2, (4,3,27,2)) y = pixel_shuffle(x, 3) @test size(y) == (12, 9, 3, 2) - # batch dimension is preserved + # batch dimension is preserved x1 = x[:,:,:,[1]] x2 = x[:,:,:,[2]] y1 = pixel_shuffle(x1, 3) @@ -83,11 +90,10 @@ end c = rand(1:5) insize = rand(1:5, d) x = rand(insize..., r^d*c, n) - + y = pixel_shuffle(x, r) @test size(y) == ((r .* insize)..., c, n) gradtest(x -> pixel_shuffle(x, r), x) end end - From 6f63fcea581b52dc0958934f60f6ca5b750fb141 Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Mon, 11 Jan 2021 17:31:19 +0100 Subject: [PATCH 02/21] cleanup --- src/upsample.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index 43bea9f65..bb50ccb44 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -29,9 +29,6 @@ The interpolation grid is identical to the one used by open-cv. Currently only 2d upsampling is supported. 
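A usage sketch of the API at this stage of the series (the `outsize` keyword is renamed to `size` later, in patch 08):

```julia
using NNlib

x  = rand(Float32, 8, 10, 3, 2)
ya = upsample_bilinear(x, (2, 3))           # per-dimension scale factors
yb = upsample_bilinear(x; outsize=(16, 30)) # or an explicit output size
@assert size(ya) == size(yb) == (16, 30, 3, 2)
```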
""" -function upsample_bilinear end - -# this is the user-facing part function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T w,h,c,n = size(x) if outsize===nothing @@ -150,8 +147,6 @@ end # Outputs - `dx`: downsampled version of `Δ` """ -function ∇upsample_bilinear end - function ∇upsample_bilinear(Δ::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T w,h,c,n = size(Δ) if outsize===nothing From 37f0eec7bbf3a48136c52c96b76486d1bca91997 Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Tue, 12 Jan 2021 14:20:36 +0100 Subject: [PATCH 03/21] remove unused code/comments, relax test atol --- src/upsample.jl | 58 ++---------------------------------------------- test/upsample.jl | 14 ++++++------ 2 files changed, 9 insertions(+), 63 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index bb50ccb44..580d0af20 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -84,60 +84,6 @@ function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArra return output end -# this is just for convenience, to support matrices - -# function upsample_bilinear_whcn!(output::AbstractMatrix, input::AbstractMatrix) -# y = reshape(output, (size(output)..., 1, 1)) -# x = reshape(input, (size(input)..., 1, 1)) -# return dropdims(upsample_bilinear_whcn!(y, x); dims=(3,4)) -# end - -# function upsample_bilinear(x::AbstractArray{T,2}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T -# w,h = size(x) -# if outsize===nothing -# out_w = floor(Int, scale[1]*w) -# out_h = floor(Int, scale[2]*h) -# else -# out_w, out_h = outsize -# end -# y = Array{T,2}(undef, out_w, out_h) -# return upsample_bilinear_whcn!(y,x) -# end - - -# here is a cwhn implementation for later use -# much faster than whcn on single core, multi-threaded they are about the same, depends on n and c -# innermost channel loop should vectorize automatically -function upsample_bilinear_cwhn!(output::AbstractArray{T,4},input::AbstractArray{T,4}) where T - if size(input) == size(output) - return input - end - size(input)[[1,4]] == size(output)[[1,4]] || error("Number of input and output channels and batches must match. 
Got input $(size(input)) and output $(size(output))") - channels, in_w, in_h, batches = size(input) - out_w, out_h = size(output)[[2,3]] - - width_scale = (in_w - 1) / (out_w - 1) # area_pixel_compute_scale(in_w, out_w, align_corners, scales[1]); - height_scale = (in_h - 1) / (out_h - 1) # area_pixel_compute_scale(in_h, out_h, align_corners, scales[2]); - - # might make more sense to parallelize an inner loop in case batches=1 - @inbounds Threads.@threads for n in 1:batches - for oh in 0:out_h-1 - ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h) - for ow in 0:out_w-1 - iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w) - for c in 1:channels - output[c, ow+1, oh+1, n] = - (h0lambda * w0lambda * input[c, iw0+1, ih0+1, n] + # h0 * w0 * i00 - h0lambda * w1lambda * input[c, iw1+1, ih0+1, n] + # h0 * w1 * i01 - h1lambda * w0lambda * input[c, iw0+1, ih1+1, n] + # h1 * w0 * i10 - h1lambda * w1lambda * input[c, iw1+1, ih1+1, n]) # h1 * w1 * i11 - end - end - end - end - return output -end - """ ∇upsample_bilinear(Δ::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T @@ -168,8 +114,8 @@ function ∇upsample_bilinear_whcn!(Δ::AbstractArray{T,4}, grad_input::Abstract out_w, out_h, _, _ = size(Δ) output_slice_size = out_h * out_w - width_scale = T((in_w - 1) // (out_w - 1)) # area_pixel_compute_scale(in_w, out_w, align_corners, scales[1]); - height_scale = T((in_h - 1) // (out_h - 1)) # area_pixel_compute_scale(in_h, out_h, align_corners, scales[2]); + width_scale = T((in_w - 1) // (out_w - 1)) + height_scale = T((in_h - 1) // (out_h - 1)) @inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1 diff --git a/test/upsample.jl b/test/upsample.jl index 3c4c7e6f0..b81cf0314 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -3,12 +3,12 @@ x = cat(x,x; dims=3) x = cat(x,x; dims=4) - y_true = [ 1//1 4//3 5//3 2//1; - 7//5 26//15 31//15 12//5; - 9//5 32//15 37//15 14//5; - 11//5 38//15 43//15 16//5; - 13//5 44//15 49//15 18//5; - 3//1 10//3 11//3 4//1] + y_true = Float32[ 1//1 4//3 5//3 2//1; + 7//5 26//15 31//15 12//5; + 9//5 32//15 37//15 14//5; + 11//5 38//15 43//15 16//5; + 13//5 44//15 49//15 18//5; + 3//1 10//3 11//3 4//1][:,:,:,:] y_true = cat(y_true,y_true; dims=3) y_true = cat(y_true,y_true; dims=4) @@ -17,7 +17,7 @@ @test eltype(y) == Float32 @test y ≈ y_true - gradtest(x->upsample_bilinear(x, (3, 2)), x, atol=1e-4) + gradtest(x->upsample_bilinear(x, (3, 2)), x, atol=1e-3) # works to higher precision for Float64 # this test can be performed again, as soon as the corresponding CUDA functionality is merged From 010cfb91be983b7fb5e256d07195809e63211c57 Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Tue, 12 Jan 2021 21:08:12 +0100 Subject: [PATCH 04/21] annotate tests, improve doc --- src/upsample.jl | 9 ++++++--- test/upsample.jl | 8 ++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index 580d0af20..5938f32df 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -20,13 +20,16 @@ end upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `scale`, -using bilinear interpolation. +using bilinear interpolation. The interpolation is identical to the one used by pytorch. 
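Concretely, matching pytorch here means the `align_corners=True` convention (the test comments later in this series say so explicitly): the source coordinate for output index `o` is `ratio*o` with `ratio = (in-1)/(out-1)`, so the first and last output samples land exactly on the input corner pixels. A standalone sketch with assumed sizes:

```julia
in_len, out_len = 3, 5
ratio = (in_len - 1) / (out_len - 1)    # the width/height_scale defined above
src = [ratio * o for o in 0:out_len-1]  # [0.0, 0.5, 1.0, 1.5, 2.0]
@assert first(src) == 0 && last(src) == in_len - 1   # corners hit exactly
# compute_source_index_and_lambda then splits each src value into floor(src)
# and its fractional part, which become the two interpolation weights.
```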
The size of the output is equal to `(scale[1]*S1, scale[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`. -The interpolation grid is identical to the one used by open-cv. - +Examples: +``` +upsample_bilinear(x, (2, pi)) # real scaling factors are allowed +upsample_bilinear(x; outsize=(64,64)) # note the semicolon, outsize is a keyword argument +``` Currently only 2d upsampling is supported. """ function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T diff --git a/test/upsample.jl b/test/upsample.jl index b81cf0314..8dad6715e 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -3,6 +3,9 @@ x = cat(x,x; dims=3) x = cat(x,x; dims=4) + # this output matches the one of pytorch v1.5.0 + # nn.UpsamplingBilinear2d(scale_factor=(3,2)) + # for above x y_true = Float32[ 1//1 4//3 5//3 2//1; 7//5 26//15 31//15 12//5; 9//5 32//15 37//15 14//5; @@ -19,6 +22,11 @@ gradtest(x->upsample_bilinear(x, (3, 2)), x, atol=1e-3) # works to higher precision for Float64 + # additional grad check, also compliant with pytorch + o = ones(Float32,6,4,1,1) + grad_true = Float32[6 6; 6 6][:,:,:,:] + @test ∇upsample_bilinear(o, (3,2)) ≈ grad_true + # this test can be performed again, as soon as the corresponding CUDA functionality is merged # if CUDA.has_cuda() From 36eaa2509cb3d4bcaf81f8733c1b5841321abb92 Mon Sep 17 00:00:00 2001 From: Max Freudenberg <67329240+maxfreu@users.noreply.github.com> Date: Wed, 13 Jan 2021 10:17:54 +0100 Subject: [PATCH 05/21] update docs --- src/upsample.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index 5938f32df..ad367fb7c 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -20,13 +20,13 @@ end upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `scale`, -using bilinear interpolation. The interpolation is identical to the one used by pytorch. +using bilinear interpolation. The size of the output is equal to `(scale[1]*S1, scale[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`. 
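For instance (an illustrative check, not part of the diff): the 2×2 test input used above with `scale = (3, 2)` yields a 6×4 output, matching the shape of `y_true` in the tests.

```julia
using NNlib

x = rand(Float32, 2, 2, 1, 1)
@assert size(upsample_bilinear(x, (3, 2))) == (3*2, 2*2, 1, 1)  # (6, 4, 1, 1)
```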
Examples: -``` +```julia upsample_bilinear(x, (2, pi)) # real scaling factors are allowed upsample_bilinear(x; outsize=(64,64)) # note the semicolon, outsize is a keyword argument ``` From 05ba078bb59974122fdd2e076e6e9775082b0dee Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Sat, 16 Jan 2021 20:46:12 +0100 Subject: [PATCH 06/21] change AutoDiff spatial rank test --- test/pooling.jl | 90 ++++++++++++++++++++++++------------------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/test/pooling.jl b/test/pooling.jl index 197b59078..33f81cd91 100644 --- a/test/pooling.jl +++ b/test/pooling.jl @@ -319,7 +319,7 @@ end # test "true" strided case, see https://github.com/FluxML/NNlib.jl/issues/205 -# obtained with +# obtained with # using FiniteDifferences maxpool_answer_nature = Dict( "rank1" => Dict( @@ -327,62 +327,62 @@ maxpool_answer_nature = Dict( "k2s1p0" => (size = (2,), stride = 1, pad = 0, - - x = reshape([ + + x = reshape([ 0.0299635, 0.233456, 0.596161, 0.161514, 0.0094027 ], 5, 1, 1), # width, channel, batch_size - + dx_maxpool = reshape([ 0.0, 1.0, 2.0, 1.0, 0.0 ], 5, 1, 1), - + dx_meanpool = reshape([ 0.5, 1.0, 1.0, 1.0, 0.5 ], 5, 1, 1),), "k2s1p1" => (size = (2,), stride = 1, pad = 1, - - x = reshape([ + + x = reshape([ 0.0299635, 0.233456, 0.596161, 0.161514, 0.0094027 ], 5, 1, 1), - + dx_maxpool = reshape([ 1.0, 1.0, 2.0, 1.0, 1.0 ], 5, 1, 1), - + dx_meanpool = reshape([ 1.0, 1.0, 1.0, 1.0, 1.0 ], 5, 1, 1),), "k3s1p1" => (size = (3,), stride = 1, pad = 1, - - x = reshape([ + + x = reshape([ 0.0299635, 0.233456, 0.596161, 0.161514, 0.0094027 ], 5, 1, 1), - + dx_maxpool = reshape([ 0.0, 1.0, 3.0, 1.0, 0.0 ], 5, 1, 1), - + dx_meanpool = reshape([ 0.6666666666, 1.0, 1.0, 1.0, 0.6666666666 ], 5, 1, 1),), "k3s2p1" => (size = (3,), stride = 2, pad = 1, - - x = reshape([ + + x = reshape([ 0.0299635, 0.233456, 0.596161, 0.161514, 0.0094027 ], 5, 1, 1), - + dx_maxpool = reshape([ 0.0, 1.0, 1.0, 1.0, 0.0 ], 5, 1, 1), - + dx_meanpool = reshape([ - 0.333333333, + 0.333333333, 0.666666666, 0.333333333, 0.666666666, @@ -395,7 +395,7 @@ maxpool_answer_nature = Dict( stride = 1, pad = 0, - x = reshape([ + x = reshape([ 0.0299635 0.233456 0.596161 0.161514 0.0094027 0.389984 0.235158 0.579525 0.301893 0.561358 0.0830242 0.483759 0.914904 0.253871 0.820061 @@ -408,9 +408,9 @@ maxpool_answer_nature = Dict( 1.0 0.0 0.0 0.0 1.0 0.0 1.0 4.0 0.0 2.0 0.0 1.0 0.0 2.0 0.0 - 0.0 2.0 0.0 0.0 0.0 + 0.0 2.0 0.0 0.0 0.0 ], 5, 5, 1, 1), - + dx_meanpool = reshape([ 0.25 0.5 0.5 0.5 0.25 0.5 1.0 1.0 1.0 0.5 @@ -421,8 +421,8 @@ maxpool_answer_nature = Dict( "k2s1p1" => (size = (2, 2), stride = 1, pad = 1, - - x = reshape([ + + x = reshape([ 0.0299635 0.233456 0.596161 0.161514 0.0094027 0.389984 0.235158 0.579525 0.301893 0.561358 0.0830242 0.483759 0.914904 0.253871 0.820061 @@ -437,7 +437,7 @@ maxpool_answer_nature = Dict( 1.0 1.0 0.0 2.0 0.0 2.0 4.0 1.0 0.0 3.0 ], 5, 5, 1, 1), - + dx_meanpool = reshape([ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 @@ -449,7 +449,7 @@ maxpool_answer_nature = Dict( stride = 1, pad = 1, - x = reshape([ + x = reshape([ 0.0299635 0.233456 0.596161 0.161514 0.0094027 0.389984 0.235158 0.579525 0.301893 0.561358 0.0830242 0.483759 0.914904 0.253871 0.820061 @@ -464,7 +464,7 @@ maxpool_answer_nature = Dict( 0.0 1.0 0.0 3.0 0.0 0.0 3.0 0.0 0.0 0.0 ], 5, 5, 1, 1), - + dx_meanpool = reshape([ 0.444444 0.666667 0.666667 0.666667 0.444444 0.666667 1.0 1.0 1.0 0.666667 @@ -476,7 +476,7 @@ maxpool_answer_nature = Dict( stride = 2, pad = 1, - x = reshape([ + x = reshape([ 
0.0299635 0.233456 0.596161 0.161514 0.0094027 0.389984 0.235158 0.579525 0.301893 0.561358 0.0830242 0.483759 0.914904 0.253871 0.820061 @@ -491,7 +491,7 @@ maxpool_answer_nature = Dict( 0.0 1.0 0.0 2.0 0.0 0.0 1.0 0.0 0.0 0.0 ], 5, 5, 1, 1), - + dx_meanpool = reshape([ 0.111111 0.222222 0.111111 0.222222 0.111111 0.222222 0.444444 0.222222 0.444444 0.222222 @@ -505,7 +505,7 @@ maxpool_answer_nature = Dict( "k2s1p0" => (size = (2, 2, 2), stride = 1, pad = 0, - + x = reshape(cat([ 0.82584 0.416818 0.92668 0.471931 0.798798 0.131608 0.344556 0.79681 @@ -524,7 +524,7 @@ maxpool_answer_nature = Dict( 0.640221 0.28811 0.129229 0.97571 0.953795 0.1316 0.94538 0.705337 ],dims=3), 4,4,3,1,1), - + dx_maxpool = reshape(cat([ 1.0 0.0 2.0 0.0 1.0 0.0 0.0 0.0 @@ -543,7 +543,7 @@ maxpool_answer_nature = Dict( 0.0 0.0 0.0 2.0 1.0 0.0 1.0 0.0 ],dims=3), 4,4,3,1,1), - + dx_meanpool = reshape(cat([ 0.125 0.25 0.25 0.125 0.25 0.5 0.5 0.25 @@ -565,7 +565,7 @@ maxpool_answer_nature = Dict( "k2s1p1" => (size = (2, 2, 2), stride = 1, pad = 1, - + x = reshape(cat([ 0.82584 0.416818 0.92668 0.471931 0.798798 0.131608 0.344556 0.79681 @@ -584,7 +584,7 @@ maxpool_answer_nature = Dict( 0.640221 0.28811 0.129229 0.97571 0.953795 0.1316 0.94538 0.705337 ],dims=3), 4,4,3,1,1), - + dx_maxpool = reshape(cat([ 8.0 0.0 8.0 2.0 4.0 0.0 1.0 4.0 @@ -603,7 +603,7 @@ maxpool_answer_nature = Dict( 3.0 0.0 0.0 8.0 8.0 0.0 6.0 1.0 ],dims=3), 4,4,3,1,1), - + dx_meanpool = reshape(cat([ 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 @@ -625,7 +625,7 @@ maxpool_answer_nature = Dict( "k3s1p1" => (size = (3, 3, 2), stride = 1, pad = 1, - + x = reshape(cat([ 0.82584 0.416818 0.92668 0.471931 0.798798 0.131608 0.344556 0.79681 @@ -644,7 +644,7 @@ maxpool_answer_nature = Dict( 0.640221 0.28811 0.129229 0.97571 0.953795 0.1316 0.94538 0.705337 ],dims=3), 4,4,3,1,1), - + dx_maxpool = reshape(cat([ 4.0 0.0 12.0 0.0 3.0 0.0 0.0 2.0 @@ -663,7 +663,7 @@ maxpool_answer_nature = Dict( 0.0 0.0 0.0 12.0 8.0 0.0 0.0 0.0 ],dims=3), 4,4,3,1,1), - + dx_meanpool = reshape(cat([ 0.444444 0.666667 0.666667 0.444444 0.666667 1.0 1.0 0.666667 @@ -685,7 +685,7 @@ maxpool_answer_nature = Dict( "k3s2p1" => (size = (3, 3, 2), stride = 2, pad = 1, - + x = reshape(cat([ 0.82584 0.416818 0.92668 0.471931 0.798798 0.131608 0.344556 0.79681 @@ -704,7 +704,7 @@ maxpool_answer_nature = Dict( 0.640221 0.28811 0.129229 0.97571 0.953795 0.1316 0.94538 0.705337 ],dims=3), 4,4,3,1,1), - + dx_maxpool = reshape(cat([ 1.0 0.0 1.0 0.0 1.0 0.0 0.0 1.0 @@ -723,7 +723,7 @@ maxpool_answer_nature = Dict( 0.0 0.0 0.0 1.0 1.0 0.0 0.0 0.0 ],dims=3), 4,4,3,1,1), - + dx_meanpool = reshape(cat([ 0.0555556 0.111111 0.0555556 0.0555556 0.111111 0.222222 0.111111 0.111111 @@ -750,7 +750,7 @@ maxpool_answer_nature = Dict( # issue #205 function check(config, T) # CHECK DEFAULT - pdims = PoolDims(config.x, config.size; stride=config.stride, padding=config.pad) + pdims = PoolDims(config.x, config.size; stride=config.stride, padding=config.pad) x = T.(config.x) y_maxpool = NNlib.maxpool(x, pdims) y_meanpool = NNlib.meanpool(x, pdims) @@ -813,15 +813,15 @@ end @testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2) x = rand(rng, repeat([10], spatial_rank)..., 3, 2) pdims = PoolDims(x, 2) - gradtest(x -> maxpool(x, pdims), x; broken=spatial_rank <= 2) + gradtest(x -> maxpool(x, pdims), x, broken=spatial_rank <= 0) # was <= 2 before gradtest(x -> meanpool(x, pdims), x) - gradtest(x -> sum(maxpool(x, pdims)), x, broken=spatial_rank == 2) + gradtest(x -> sum(maxpool(x, pdims)), x, 
broken=spatial_rank == 0) # was == 2 before gradtest(x -> sum(meanpool(x, pdims)), x) #https://github.com/FluxML/NNlib.jl/issues/188 k = ntuple(_ -> 2, spatial_rank) # Kernel size of pool in ntuple format - gradtest(x -> maxpool(x, k), x; broken=spatial_rank <= 2) + gradtest(x -> maxpool(x, k), x, broken=spatial_rank <= 0) # was <= 2 before gradtest(x -> meanpool(x, k), x) - gradtest(x -> sum(maxpool(x, k)), x, broken=spatial_rank == 2) + gradtest(x -> sum(maxpool(x, k)), x, broken=spatial_rank == 0) # was == 2 before gradtest(x -> sum(meanpool(x, k)), x) end From 2763d9d31c6242f0f2189c6e793373386f255371 Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Mon, 18 Jan 2021 13:14:37 +0100 Subject: [PATCH 07/21] introduce single-number arg and test for real upscaling factors --- src/upsample.jl | 4 ++++ test/upsample.jl | 11 ++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/upsample.jl b/src/upsample.jl index ad367fb7c..8a9da5136 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -44,6 +44,8 @@ function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); o return upsample_bilinear_whcn!(y, x) end +upsample_bilinear(x, scale::Real; outsize=nothing) = upsample_bilinear(x, (scale,scale); outsize=outsize) + # this is the core function which works on arrays of arbitrary size # the implementation is a translation of https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp # which implements open-cv style linear interpolation / upsampling @@ -108,6 +110,8 @@ function ∇upsample_bilinear(Δ::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1 return ∇upsample_bilinear_whcn!(Δ, dx) end +∇upsample_bilinear(Δ, scale::Real; outsize=nothing) = ∇upsample_bilinear(Δ, (scale, scale); outsize=outsize) + function ∇upsample_bilinear_whcn!(Δ::AbstractArray{T,4}, grad_input::AbstractArray{T,4}) where T size(grad_input)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. 
Got input $(size(input)) and output $(size(output))") in_w, in_h, channels, batches = size(grad_input) diff --git a/test/upsample.jl b/test/upsample.jl index 8dad6715e..08fb81206 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -4,7 +4,7 @@ x = cat(x,x; dims=4) # this output matches the one of pytorch v1.5.0 - # nn.UpsamplingBilinear2d(scale_factor=(3,2)) + # nn.UpsamplingBilinear2d(scale_factor=(3,2), align_corners=True) # for above x y_true = Float32[ 1//1 4//3 5//3 2//1; 7//5 26//15 31//15 12//5; @@ -27,6 +27,15 @@ grad_true = Float32[6 6; 6 6][:,:,:,:] @test ∇upsample_bilinear(o, (3,2)) ≈ grad_true + y_true_2 = Rational{Int}[1//1 5//4 6//4 7//4 2//1; + 3//2 7//4 8//4 9//4 5//2; + 4//2 9//4 10//4 11//4 6//2; + 5//2 11//4 12//4 13//4 7//2; + 3//1 13//4 14//4 15//4 4//1][:,:,:,:] + + # check for real-valued single-number argument and type stability for rationals + upsample_bilinear(x, 2.5) == y_true_2 + # this test can be performed again, as soon as the corresponding CUDA functionality is merged # if CUDA.has_cuda() From 079e71bd26e56c4fb82a74c5c12756e513a8d219 Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Thu, 21 Jan 2021 12:35:06 +0100 Subject: [PATCH 08/21] outsize -> size, improve grad test --- src/upsample.jl | 32 ++++++++++++++++---------------- test/upsample.jl | 4 ++-- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index 8a9da5136..77d6b7c8d 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -17,7 +17,7 @@ export upsample_bilinear, ∇upsample_bilinear, pixel_shuffle end """ - upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) + upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing) Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `scale`, using bilinear interpolation. @@ -28,23 +28,23 @@ The size of the output is equal to Examples: ```julia upsample_bilinear(x, (2, pi)) # real scaling factors are allowed -upsample_bilinear(x; outsize=(64,64)) # note the semicolon, outsize is a keyword argument +upsample_bilinear(x; size=(64,64)) # note the semicolon, size is a keyword argument ``` Currently only 2d upsampling is supported. 
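One side effect of naming the keyword `size` is worth a sketch: inside such a method the keyword binding shadows the function `Base.size`, so every call must be qualified — which is why the diff just below writes the qualified form (the doubled `Base.Base.size` that slips in here is corrected by the later "tiny fix" commit). A hypothetical standalone illustration, with `outshape` a made-up name:

```julia
# The point is the shadowing rule: bare `size` inside the body is the kwarg.
function outshape(x::AbstractArray; size::NTuple{2,Int})
    w, h = Base.size(x, 1), Base.size(x, 2)  # must qualify to reach Base.size
    return (w, h), size                      # `size` here is the keyword tuple
end
outshape(rand(4, 4); size=(8, 8))  # ((4, 4), (8, 8))
```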
""" -function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T - w,h,c,n = size(x) - if outsize===nothing +function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing) where T + w,h,c,n = Base.Base.size(x) + if size===nothing out_w = floor(Int, scale[1]*w) out_h = floor(Int, scale[2]*h) else - out_w, out_h = outsize + out_w, out_h = size end y = Array{T,4}(undef, out_w, out_h, c, n) return upsample_bilinear_whcn!(y, x) end -upsample_bilinear(x, scale::Real; outsize=nothing) = upsample_bilinear(x, (scale,scale); outsize=outsize) +upsample_bilinear(x, scale::Real; size=nothing) = upsample_bilinear(x, (scale,scale); size=size) # this is the core function which works on arrays of arbitrary size # the implementation is a translation of https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp @@ -90,7 +90,7 @@ function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArra end """ - ∇upsample_bilinear(Δ::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T + ∇upsample_bilinear(Δ::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing) where T # Arguments - `Δ`: incoming gradient array that has been upsampled using the upsample factors in `scale` @@ -98,19 +98,19 @@ end # Outputs - `dx`: downsampled version of `Δ` """ -function ∇upsample_bilinear(Δ::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T - w,h,c,n = size(Δ) - if outsize===nothing +function ∇upsample_bilinear(Δ::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing) where T + w,h,c,n = Base.size(Δ) + if size===nothing out_w = ceil(Int, w/scale[1]) out_h = ceil(Int, h/scale[2]) else - out_w, out_h = outsize + out_w, out_h = size end dx = zeros(T, out_w, out_h, c, n) return ∇upsample_bilinear_whcn!(Δ, dx) end -∇upsample_bilinear(Δ, scale::Real; outsize=nothing) = ∇upsample_bilinear(Δ, (scale, scale); outsize=outsize) +∇upsample_bilinear(Δ, scale::Real; size=nothing) = ∇upsample_bilinear(Δ, (scale, scale); size=size) function ∇upsample_bilinear_whcn!(Δ::AbstractArray{T,4}, grad_input::AbstractArray{T,4}) where T size(grad_input)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. 
Got input $(size(input)) and output $(size(output))") @@ -143,10 +143,10 @@ function ∇upsample_bilinear_whcn!(Δ::AbstractArray{T,4}, grad_input::Abstract return grad_input end -function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, scale; outsize=nothing) - Ω = upsample_bilinear(x, scale; outsize=outsize) +function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, scale; size=nothing) + Ω = upsample_bilinear(x, scale; size=size) function upsample_bilinear_pullback(Δ) - (NO_FIELDS, ∇upsample_bilinear(Δ, scale; outsize=(size(x,1),size(x,2))), DoesNotExist()) + (NO_FIELDS, ∇upsample_bilinear(Δ, scale; size=(Base.size(x,1),Base.size(x,2))), DoesNotExist()) end return Ω, upsample_bilinear_pullback end diff --git a/test/upsample.jl b/test/upsample.jl index 08fb81206..7713dee88 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -23,8 +23,8 @@ gradtest(x->upsample_bilinear(x, (3, 2)), x, atol=1e-3) # works to higher precision for Float64 # additional grad check, also compliant with pytorch - o = ones(Float32,6,4,1,1) - grad_true = Float32[6 6; 6 6][:,:,:,:] + o = ones(Float32,6,4,2,1) + grad_true = 6*ones(Float32,2,2,2,1) @test ∇upsample_bilinear(o, (3,2)) ≈ grad_true y_true_2 = Rational{Int}[1//1 5//4 6//4 7//4 2//1; From e2b3051297ea76fc2d10599dc187aa7ea3dfb9e8 Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Fri, 22 Jan 2021 11:30:13 +0100 Subject: [PATCH 09/21] error when both args given --- src/upsample.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/upsample.jl b/src/upsample.jl index 77d6b7c8d..e13977b03 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -34,7 +34,10 @@ Currently only 2d upsampling is supported. """ function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing) where T w,h,c,n = Base.Base.size(x) - if size===nothing + if scale != (1,1) && size !== nothing + error("Please provide either scale or size, not both. Got scale=$scale and size=$size.") + end + if size === nothing out_w = floor(Int, scale[1]*w) out_h = floor(Int, scale[2]*h) else @@ -100,6 +103,9 @@ end """ function ∇upsample_bilinear(Δ::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing) where T w,h,c,n = Base.size(Δ) + if scale != (1,1) && size !== nothing + error("Please provide either scale or size, not both. 
Got scale=$scale and size=$size.") + end if size===nothing out_w = ceil(Int, w/scale[1]) out_h = ceil(Int, h/scale[2]) From a0185fda9c6ae0b903a3ddc8b39bdb63977c361d Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Fri, 22 Jan 2021 11:52:53 +0100 Subject: [PATCH 10/21] add simple integer support --- src/upsample.jl | 6 ++++++ test/upsample.jl | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/src/upsample.jl b/src/upsample.jl index e13977b03..2b1373752 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -49,6 +49,12 @@ end upsample_bilinear(x, scale::Real; size=nothing) = upsample_bilinear(x, (scale,scale); size=size) +function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size=nothing) where T<:Integer + y = float.(x) + res = upsample_bilinear(y, scale; size=size) + return round.(T, res) +end + # this is the core function which works on arrays of arbitrary size # the implementation is a translation of https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp # which implements open-cv style linear interpolation / upsampling diff --git a/test/upsample.jl b/test/upsample.jl index 7713dee88..a056ed3d1 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -36,6 +36,15 @@ # check for real-valued single-number argument and type stability for rationals upsample_bilinear(x, 2.5) == y_true_2 + # check Integer support for forward pass + # grads are always assumed to be floats, so no extension there + x = UInt8[1 3; 3 5][:,:,:,:] + y_true_int = UInt8[1 2 3; 2 3 4; 3 4 5][:,:,:,:] + y = upsample_bilinear(x, 1.5) + + @test eltype(y) == UInt8 + @test y == y_true_int + # this test can be performed again, as soon as the corresponding CUDA functionality is merged # if CUDA.has_cuda() From e95b27d0fb082ff97c5c68aec885e252befc9b87 Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Fri, 22 Jan 2021 13:27:56 +0100 Subject: [PATCH 11/21] tiny fix --- src/upsample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/upsample.jl b/src/upsample.jl index e4a36cc0d..bcc12ac05 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -97,7 +97,7 @@ upsample_bilinear(x; size=(64,64)) # note the semicolon, size is a keyword argum Currently only 2d upsampling is supported. """ function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing) where T - w,h,c,n = Base.Base.size(x) + w,h,c,n = Base.size(x) if scale != (1,1) && size !== nothing error("Please provide either scale or size, not both. 
Got scale=$scale and size=$size.") end From 1bfe25da595e9db377bfd35b7d9876bcec62ca9c Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Fri, 22 Jan 2021 15:46:43 +0100 Subject: [PATCH 12/21] fix rrule --- src/upsample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/upsample.jl b/src/upsample.jl index bcc12ac05..b58b0f5c2 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -222,7 +222,7 @@ end function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, scale; size=nothing) Ω = upsample_bilinear(x, scale; size=size) function upsample_bilinear_pullback(Δ) - (NO_FIELDS, ∇upsample_bilinear(Δ, scale; size=(Base.size(x,1),Base.size(x,2))), DoesNotExist()) + (NO_FIELDS, ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2))), DoesNotExist()) end return Ω, upsample_bilinear_pullback end From 975caef2645f1b6ad70a166e4c26d666cc77387b Mon Sep 17 00:00:00 2001 From: Max Freudenberg <67329240+maxfreu@users.noreply.github.com> Date: Fri, 22 Jan 2021 16:10:07 +0100 Subject: [PATCH 13/21] add default scale in rrule Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/upsample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/upsample.jl b/src/upsample.jl index b58b0f5c2..58500aec8 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -219,7 +219,7 @@ function ∇upsample_bilinear_whcn!(Δ::AbstractArray{T,4}, grad_input::Abstract return grad_input end -function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, scale; size=nothing) +function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, scale=(1,1); size=nothing) Ω = upsample_bilinear(x, scale; size=size) function upsample_bilinear_pullback(Δ) (NO_FIELDS, ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2))), DoesNotExist()) From f74209742b4e9105e709559368ec57c416080ec6 Mon Sep 17 00:00:00 2001 From: Max Freudenberg <67329240+maxfreu@users.noreply.github.com> Date: Tue, 26 Jan 2021 12:21:33 +0100 Subject: [PATCH 14/21] thin-out gradient API Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/upsample.jl | 13 ++----------- test/upsample.jl | 2 +- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index 58500aec8..dc2325478 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -171,22 +171,13 @@ end # Outputs - `dx`: downsampled version of `Δ` """ -function ∇upsample_bilinear(Δ::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing) where T - w,h,c,n = Base.size(Δ) - if scale != (1,1) && size !== nothing - error("Please provide either scale or size, not both. Got scale=$scale and size=$size.") - end - if size===nothing - out_w = ceil(Int, w/scale[1]) - out_h = ceil(Int, h/scale[2]) - else - out_w, out_h = size +function ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) where T + out_w, out_h = size end dx = zeros(T, out_w, out_h, c, n) return ∇upsample_bilinear_whcn!(Δ, dx) end -∇upsample_bilinear(Δ, scale::Real; size=nothing) = ∇upsample_bilinear(Δ, (scale, scale); size=size) function ∇upsample_bilinear_whcn!(Δ::AbstractArray{T,4}, grad_input::AbstractArray{T,4}) where T size(grad_input)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. 
Got input $(size(input)) and output $(size(output))") diff --git a/test/upsample.jl b/test/upsample.jl index 82a00ec1b..7dec97789 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -39,7 +39,7 @@ end # additional grad check, also compliant with pytorch o = ones(Float32,6,4,2,1) grad_true = 6*ones(Float32,2,2,2,1) - @test ∇upsample_bilinear(o, (3,2)) ≈ grad_true + @test ∇upsample_bilinear(o; size = (2,2)) ≈ grad_true y_true_2 = Rational{Int}[1//1 5//4 6//4 7//4 2//1; 3//2 7//4 8//4 9//4 5//2; From ce8fa4fa6f6e3f9faf86330ab2810b855a15e29d Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Tue, 26 Jan 2021 14:04:23 +0100 Subject: [PATCH 15/21] tiny fix, fix docs --- src/upsample.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index dc2325478..aa1e82d92 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -163,22 +163,22 @@ function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArra end """ - ∇upsample_bilinear(Δ::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing) where T + ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::Union{Nothing,NTuple{2,Integer}}=nothing) where T # Arguments -- `Δ`: incoming gradient array that has been upsampled using the upsample factors in `scale` +- `Δ`: Incoming gradient array, backpropagated from downstream layers +- `size`: Lateral (W,H) size of the image upsampled in the first place # Outputs -- `dx`: downsampled version of `Δ` +- `dx`: Downsampled version of `Δ` """ function ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) where T + _, _, c, n = Base.size(Δ) out_w, out_h = size - end dx = zeros(T, out_w, out_h, c, n) return ∇upsample_bilinear_whcn!(Δ, dx) end - function ∇upsample_bilinear_whcn!(Δ::AbstractArray{T,4}, grad_input::AbstractArray{T,4}) where T size(grad_input)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. 
Got input $(size(input)) and output $(size(output))") in_w, in_h, channels, batches = size(grad_input) From 95c1e8a771cb4d52598113c45d5d62c15272de92 Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Wed, 27 Jan 2021 11:44:55 +0100 Subject: [PATCH 16/21] rename kernel, change arg order --- src/upsample.jl | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index aa1e82d92..a631ca8d7 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -108,7 +108,7 @@ function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); s out_w, out_h = size end y = Array{T,4}(undef, out_w, out_h, c, n) - return upsample_bilinear_whcn!(y, x) + return upsample_bilinear_whcn_kernel!(y, x) end upsample_bilinear(x, scale::Real; size=nothing) = upsample_bilinear(x, (scale,scale); size=size) @@ -128,7 +128,7 @@ end # - RGB types could be supported via reinterpreting # - integer types need to be converted to Float and back # - rationals work, but are slow -function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T +function upsample_bilinear_whcn_kernel!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T if size(input) == size(output) return input end @@ -176,12 +176,16 @@ function ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) w _, _, c, n = Base.size(Δ) out_w, out_h = size dx = zeros(T, out_w, out_h, c, n) - return ∇upsample_bilinear_whcn!(Δ, dx) + return ∇upsample_bilinear_whcn_kernel!(dx, Δ) end -function ∇upsample_bilinear_whcn!(Δ::AbstractArray{T,4}, grad_input::AbstractArray{T,4}) where T - size(grad_input)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))") - in_w, in_h, channels, batches = size(grad_input) +function ∇upsample_bilinear_whcn_kernel!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T + if size(dx) == size(Δ) + return Δ + end + + size(dx)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. 
Got input $(size(input)) and output $(size(output))") + in_w, in_h, channels, batches = size(dx) # treat batch and channel dimension as one for better parallelization granularity channels *= batches @@ -200,14 +204,14 @@ function ∇upsample_bilinear_whcn!(Δ::AbstractArray{T,4}, grad_input::Abstract iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w) output_offset = c * output_slice_size + oh * out_w + ow + 1 Δ_value = Δ[output_offset] - grad_input[idx(c, ih0, iw0)] += h0lambda * w0lambda * Δ_value # i00 - grad_input[idx(c, ih0, iw1)] += h0lambda * w1lambda * Δ_value # i01 - grad_input[idx(c, ih1, iw0)] += h1lambda * w0lambda * Δ_value # i10 - grad_input[idx(c, ih1, iw1)] += h1lambda * w1lambda * Δ_value # i11 + dx[idx(c, ih0, iw0)] += h0lambda * w0lambda * Δ_value # i00 + dx[idx(c, ih0, iw1)] += h0lambda * w1lambda * Δ_value # i01 + dx[idx(c, ih1, iw0)] += h1lambda * w0lambda * Δ_value # i10 + dx[idx(c, ih1, iw1)] += h1lambda * w1lambda * Δ_value # i11 end end end - return grad_input + return dx end function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, scale=(1,1); size=nothing) From eee022851fa6b4ca9c3f693d79456e82e2999737 Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Wed, 27 Jan 2021 12:24:17 +0100 Subject: [PATCH 17/21] =?UTF-8?q?introduce=20(=E2=88=87)upsample=5Fbilinea?= =?UTF-8?q?r!=20indirection?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/upsample.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index a631ca8d7..e7ecdfeea 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -107,8 +107,8 @@ function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); s else out_w, out_h = size end - y = Array{T,4}(undef, out_w, out_h, c, n) - return upsample_bilinear_whcn_kernel!(y, x) + y = similar(x, T, out_w, out_h, c, n) + return upsample_bilinear!(y, x) end upsample_bilinear(x, scale::Real; size=nothing) = upsample_bilinear(x, (scale,scale); size=size) @@ -119,6 +119,8 @@ function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); s return round.(T, res) end +upsample_bilinear!(y::AbstractArray{<:Any,4}, x::AbstractArray{<:Any,4}) = upsample_bilinear_whcn_kernel!(y,x) + # this is the core function which works on arrays of arbitrary size # the implementation is a translation of https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp # which implements open-cv style linear interpolation / upsampling @@ -175,15 +177,17 @@ end function ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) where T _, _, c, n = Base.size(Δ) out_w, out_h = size - dx = zeros(T, out_w, out_h, c, n) - return ∇upsample_bilinear_whcn_kernel!(dx, Δ) + dx = zero(similar(Δ, T, out_w, out_h, c, n)) + return ∇upsample_bilinear!(dx, Δ) end +∇upsample_bilinear!(dx::AbstractArray{<:Any,4}, Δ::AbstractArray{<:Any,4}) = ∇upsample_bilinear_whcn_kernel!(dx, Δ) + function ∇upsample_bilinear_whcn_kernel!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T if size(dx) == size(Δ) return Δ end - + size(dx)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))") in_w, in_h, channels, batches = size(dx) From 266addabb9e3650242828817047fa56105729062 Mon Sep 17 00:00:00 2001 From: Max Freudenberg Date: Fri, 29 Jan 2021 11:44:13 +0100 Subject: [PATCH 18/21] finalize internal API (?) 
From 1827d417dc81ade8040631f08c47f9fca805ec24 Mon Sep 17 00:00:00 2001
From: Max Freudenberg <67329240+maxfreu@users.noreply.github.com>
Date: Mon, 1 Feb 2021 11:35:24 +0100
Subject: [PATCH 19/21] improve docs, clean up grad signature

Co-authored-by: Carlo Lucibello
---
 src/upsample.jl | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/upsample.jl b/src/upsample.jl
index a1f5d083f..c506c8876 100644
--- a/src/upsample.jl
+++ b/src/upsample.jl
@@ -89,12 +89,13 @@ using bilinear interpolation.
 
 The size of the output is equal to
 `(scale[1]*S1, scale[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.
 
+As an alternative to using `scale`, the resulting image `size` can be directly specified with a keyword argument.
+
 Examples:
 ```julia
 upsample_bilinear(x, (2, pi)) # real scaling factors are allowed
-upsample_bilinear(x; size=(64,64)) # note the semicolon, size is a keyword argument
+upsample_bilinear(x; size=(64,64)) # specify output size
 ```
-Currently only 2d upsampling is supported.
 """
 function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing) where T
     w,h,c,n = Base.size(x)
@@ -163,7 +164,7 @@ end
 
 """
-    ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::Union{Nothing,NTuple{2,Integer}}=nothing) where T
+    ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) where T
 
 # Arguments
 - `Δ`: Incoming gradient array, backpropagated from downstream layers
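With patch 19's signature cleanup, `size` is a mandatory keyword of the gradient, so callers must state the input-grid size explicitly. A short usage sketch (illustrative, not part of the patches):

```julia
Δ  = ones(Float32, 8, 8, 1, 1)            # gradient w.r.t. the upsampled output
dx = ∇upsample_bilinear(Δ; size=(4, 4))   # pull it back onto the 4x4 input grid
@assert size(dx) == (4, 4, 1, 1)
```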
From 7240ecf10a0446acfed03cc07205a17b95378ec9 Mon Sep 17 00:00:00 2001
From: Max Freudenberg
Date: Mon, 1 Feb 2021 12:20:38 +0100
Subject: [PATCH 20/21] split function into two methods for size/scale

---
 src/upsample.jl | 40 ++++++++++++++++++----------------------
 1 file changed, 18 insertions(+), 22 deletions(-)

diff --git a/src/upsample.jl b/src/upsample.jl
index c506c8876..2fb55fab3 100644
--- a/src/upsample.jl
+++ b/src/upsample.jl
@@ -81,45 +81,41 @@ end
 end
 
 """
-    upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing)
+    upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real})
+    upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer})
 
 Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `scale`,
-using bilinear interpolation. 
+using bilinear interpolation. As an alternative to using `scale`, the resulting image `size`
+can be directly specified with a keyword argument.
 
 The size of the output is equal to
 `(scale[1]*S1, scale[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.
 
-As an alternative to using `scale`, the resulting image `size` can be directly specified with a keyword argument.
-
 Examples:
 ```julia
 upsample_bilinear(x, (2, pi)) # real scaling factors are allowed
 upsample_bilinear(x; size=(64,64)) # specify output size
 ```
 """
-function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing) where T
+function upsample_bilinear(x::AbstractArray{<:Any,4}, scale::NTuple{2,Real})
+    outsize = ntuple(i -> floor(Int, scale[i] * Base.size(x, i)), 2)
+    return upsample_bilinear(x; size=outsize)
+end
+
+upsample_bilinear(x, scale::Real) = upsample_bilinear(x, (scale,scale))
+
+function upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}) where T
     w,h,c,n = Base.size(x)
-    if scale != (1,1) && size !== nothing
-        error("Please provide either scale or size, not both. Got scale=$scale and size=$size.")
-    end
-    if size === nothing
-        out_w = floor(Int, scale[1]*w)
-        out_h = floor(Int, scale[2]*h)
-    else
-        out_w, out_h = size
-    end
-    if (w,h) == (out_w, out_h)
+    if (w,h) == size
         return x
     end
-    y = similar(x, T, out_w, out_h, c, n)
+    y = similar(x, T, size..., c, n)
     return upsample_bilinear_whcn!(y, x)
 end
 
-upsample_bilinear(x, scale::Real; size=nothing) = upsample_bilinear(x, (scale,scale); size=size)
-
-function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size=nothing) where T<:Integer
+function upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}) where T<:Integer
     y = float.(x)
-    res = upsample_bilinear(y, scale; size=size)
+    res = upsample_bilinear(y; size=size)
     return round.(T, res)
 end
 
@@ -214,8 +210,8 @@ function ∇upsample_bilinear_whcn!(dx::AbstractArray{T,
     return dx
 end
 
-function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, scale=(1,1); size=nothing)
-    Ω = upsample_bilinear(x, scale; size=size)
+function ChainRulesCore.rrule(::typeof(upsample_bilinear), x; size)
+    Ω = upsample_bilinear(x; size=size)
     function upsample_bilinear_pullback(Δ)
         (NO_FIELDS, ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2))), DoesNotExist())
     end
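After patch 20's split, the `scale` method only computes an output size and forwards to the `size` method, so both spellings below run through the same code path; the runtime "either scale or size" check becomes unnecessary because the two entry points can no longer collide. A sketch:

```julia
x = rand(Float32, 10, 10, 3, 1)

a = upsample_bilinear(x, (2, 2))          # scale method: computes outsize = (20, 20)
b = upsample_bilinear(x; size=(20, 20))   # size method, called directly
@assert a == b
```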
From 8e8e0dd784ac1d4c94e3aefba9b53e705bd62144 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Tue, 2 Feb 2021 10:13:19 +0100
Subject: [PATCH 21/21] fix rrule

---
 src/upsample.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/upsample.jl b/src/upsample.jl
index 2fb55fab3..a67c156aa 100644
--- a/src/upsample.jl
+++ b/src/upsample.jl
@@ -213,7 +213,7 @@ end
 function ChainRulesCore.rrule(::typeof(upsample_bilinear), x; size)
     Ω = upsample_bilinear(x; size=size)
     function upsample_bilinear_pullback(Δ)
-        (NO_FIELDS, ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2))), DoesNotExist())
+        (NO_FIELDS, ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2))))
     end
     return Ω, upsample_bilinear_pullback
 end
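The pullback of an `rrule` must return one (co)tangent per primal argument: one for the function itself (`NO_FIELDS`) and one for `x`. The trailing `DoesNotExist()` was a leftover for the `scale` argument that patch 20 removed from the rrule's signature, which is why patch 21 drops it. A sketch of exercising the fixed rule directly, assuming the pre-1.0 ChainRulesCore API this series is written against:

```julia
using ChainRulesCore

x = rand(Float32, 4, 4, 1, 1)
Ω, pullback = ChainRulesCore.rrule(upsample_bilinear, x; size=(8, 8))
tangents = pullback(ones(Float32, size(Ω)...))
@assert length(tangents) == 2          # (NO_FIELDS, dx): one entry per primal argument
@assert size(tangents[2]) == size(x)   # gradient has the input's shape
```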