improve bilinear upsampling #266

Merged: 23 commits, Feb 5, 2021
Changes from 19 commits
327 changes: 133 additions & 194 deletions src/upsample.jl
@@ -64,235 +64,174 @@ function ChainRulesCore.rrule(::typeof(upsample_nearest), x::AbstractArray, s::T
return Ω, upsample_nearest_pullback
end

"""
upsample_bilinear(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})

Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `k`,
using bilinear interpolation.

The size of the output is equal to
`(k[1]*S1, k[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.

The interpolation grid is identical to the one used by `imresize` from `Images.jl`.

Currently only 2d upsampling is supported.
"""
function upsample_bilinear(x::AbstractArray{T,4}, k::NTuple{2,Int}) where T
# This function is gpu friendly

imgsize = size(x)
newsize = get_newsize(imgsize, k)

# Get linear interpolation lower- and upper index, and weights
ilow1, ihigh1, wdiff1 = get_inds_and_ws(x, imgsize[1], newsize[1], 1)
ilow2, ihigh2, wdiff2 = get_inds_and_ws(x, imgsize[2], newsize[2], 2)

# Adjust the upper interpolation indices of the second dimension
ihigh2_r = adjoint_of_idx(ilow2)[ihigh2]

@inbounds y = @view(x[ilow1,ilow2,:,:]) .* (1 .- wdiff1) .+ @view(x[ihigh1,ilow2,:,:]) .* wdiff1
@inbounds y .= y .* (1 .- wdiff2) .+ y[:,ihigh2_r,:,:] .* wdiff2
# @inbounds y = y .* (1 .- wdiff2) .+ @view(y[:,ihigh2_r,:,:]) .* wdiff2 # equivalent to line above
return y
end

function get_inds_and_ws(x::T, n::Int, m::Int, dim::Int) where T <: AbstractArray
# Creates interpolation grid for resampling.
# Creates the same grid as used in Image.jl `imresize`.
step = n // m
offset = (n + 1)//2 - step//2 - step * (m//2 - 1)
xq = clamp.(range(offset, step=step, length=m), 1, n)

# Creates interpolation lower and upper indices, and broadcastable weights
ilow = floor.(Int, xq)
ihigh = ceil.(Int, xq)
sizew = ntuple(i-> i == dim ? length(xq) : 1, ndims(x))
wdiff = convert(T, reshape(xq .- ilow, sizew)) # wdiff possibly lives on gpu
return ilow, ihigh, wdiff
end

# utility function
@inline function compute_source_index_and_lambda(
ratio, # 0 < ratio < 1
output_index,
input_size,
output_size
)
real_input_index = ratio*output_index
input_index0 = floor(Int, real_input_index) # typecast to int was here in C++
offset = (input_index0 < input_size - 1) ? 1 : 0
input_index1 = input_index0 + offset
lambda1 = real_input_index - input_index0
lambda0 = 1 - lambda1
return input_index0, input_index1, lambda0, lambda1
end
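# Illustrative sketch (not part of the diff): a worked example of the index/weight
# computation above, assuming the align_corners-style ratio (in - 1)/(out - 1) that the
# kernels below pass in. For a 3-pixel row stretched to 5 pixels, output pixel 3
# (0-based) falls halfway between source pixels 1 and 2.
let in_len = 3, out_len = 5
    ratio = (in_len - 1) / (out_len - 1)   # 0.5
    i0, i1, l0, l1 = compute_source_index_and_lambda(ratio, 3, in_len, out_len)
    @assert (i0, i1) == (1, 2)             # 0-based source indices straddling 1.5
    @assert l0 == 0.5 && l1 == 0.5         # blend weights
end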

"""
adjoint_of_idx(idx::Vector{<:Integer})

# Arguments
- `idx`: a vector of indices from which you want the adjoint.
upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing)

# Outputs
-`idx_adjoint`: index that inverses the operation `x[idx]`.
Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `scale`,
using bilinear interpolation.

# Explanation
Determines the adjoint of the vector of indices `idx`, based on the following assumptions:
* `idx[1] == 1`
* `all(d in [0,1] for d in diff(idx))`
The adjoint of `idx` can be seen as an inverse operation such that:
The size of the output is equal to
`(scale[1]*S1, scale[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.

Examples:
```julia
x = [1, 2, 3, 4, 5]
idx = [1, 2, 2, 3, 4, 4, 5]
idx_adjoint = adjoint_of_idx(idx)
@assert x[idx][idx_adjoint] == x
upsample_bilinear(x, (2, pi)) # real scaling factors are allowed
upsample_bilinear(x; size=(64,64)) # note the semicolon, size is a keyword argument
```
The above holds as long as `idx` contains every index in `x`.
Currently only 2d upsampling is supported.
"""
function adjoint_of_idx(idx::Vector{Int})
d = trues(length(idx))
d[2:end] .= diff(idx)
idx_adjoint = findall(d)
return idx_adjoint
function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size::Union{Nothing,NTuple{2,Integer}}=nothing) where T
w,h,c,n = Base.size(x)
if scale != (1,1) && size !== nothing
error("Please provide either scale or size, not both. Got scale=$scale and size=$size.")
end
if size === nothing
out_w = floor(Int, scale[1]*w)
out_h = floor(Int, scale[2]*h)
else
out_w, out_h = size
end
y = similar(x, T, out_w, out_h, c, n)
return upsample_bilinear!(y, x)
end

function get_newsize(sz, k)
return ntuple(i -> i <= length(k) ? sz[i]*k[i] : sz[i], length(sz))
end

upsample_bilinear(x, scale::Real; size=nothing) = upsample_bilinear(x, (scale,scale); size=size)

function upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real}=(1,1); size=nothing) where T<:Integer
y = float.(x)
res = upsample_bilinear(y, scale; size=size)
return round.(T, res)
end

upsample_bilinear!(y::AbstractArray{<:Any,4}, x::AbstractArray{<:Any,4}) = upsample_bilinear_whcn_kernel!(y,x)

# this is the core function which works on arrays of arbitrary size
# the implementation is a translation of https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp
# which implements open-cv style linear interpolation / upsampling
# for simplicity, corners are aligned and all logic for other behaviour has been stripped
# - whcn because there is also a cwhn implementation
# - the function is parallelized using @threads
# - RGB types could be supported via reinterpreting
# - integer types need to be converted to Float and back
# - rationals work, but are slow
function upsample_bilinear_whcn_kernel!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T
if size(input) == size(output)
return input
end
size(input)[3:4] == size(output)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
in_w, in_h, channels, batches = size(input)
# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, out_h, _, _ = size(output)
output_slice_size = out_h * out_w

# T() and // so that we can handle rationals (super slow)
width_scale = T((in_w - 1) // (out_w - 1))
height_scale = T((in_h - 1) // (out_h - 1))

@inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1

@inbounds Threads.@threads for c in 0:channels-1
for oh in 0:out_h-1
ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h)
for ow in 0:out_w-1
iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
output_offset = c * output_slice_size + oh * out_w + ow + 1
output[output_offset] =
(h0lambda * w0lambda * input[idx(c, ih0, iw0)] + # h0 * w0 * i00
h0lambda * w1lambda * input[idx(c, ih0, iw1)] + # h0 * w1 * i01
h1lambda * w0lambda * input[idx(c, ih1, iw0)] + # h1 * w0 * i10
h1lambda * w1lambda * input[idx(c, ih1, iw1)]) # h1 * w1 * i11
end
end
end
return output
end
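# Illustrative sketch (not part of the diff): with corners aligned, a 2×2 ramp
# upsampled to 3×3 interpolates linearly between the four corner values, and the
# positional `scale` / keyword `size` arguments control the output shape as documented above.
let x = Float32.(reshape([0 2; 1 3], 2, 2, 1, 1))
    y = upsample_bilinear(x; size=(3, 3))
    @assert y[:, :, 1, 1] == Float32[0 1 2; 0.5 1.5 2.5; 1 2 3]
    @assert Base.size(upsample_bilinear(x, (2, 2))) == (4, 4, 1, 1)
end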

"""
∇upsample_bilinear(Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int})
∇upsample_bilinear(Δ::AbstractArray{T,4}; size::Union{Nothing,NTuple{2,Integer}}=nothing) where T

# Arguments
- `Δ`: array that has been upsampled using the upsample factors in `k`
- `Δ`: Incoming gradient array, backpropagated from downstream layers
- `size`: Lateral (W,H) size of the image upsampled in the first place

# Outputs
- `dx`: downsampled version of `Δ`

# Explanation

Custom adjoint for [`upsample_bilinear`](@ref).
The adjoint of upsampling is a downsampling operation, which
in this implementation is performed using `NNlib.conv` in combination with a downsampling kernel based on the
upsampling factors. Because of the zero-padding during convolution, the values at the boundary are polluted by edge-effects,
which have been corrected for manually.
- `dx`: Downsampled version of `Δ`
"""
function ∇upsample_bilinear(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
# This function is gpu friendly

# Be more efficient on some corner cases
if size(Δ, 1) == k[1]
Δ = sum(Δ, dims=1)
k = (1, k[2])
end
if size(Δ, 2) == k[2]
Δ = sum(Δ, dims=2)
k = (k[1], 1)
end
if (size(Δ, 1) == 1) && (size(Δ, 2) == 1)
dx = Δ
return dx
end
function ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) where T
_, _, c, n = Base.size(Δ)
out_w, out_h = size
dx = zero(similar(Δ, T, out_w, out_h, c, n))
return ∇upsample_bilinear!(dx, Δ)
end

n_chan, n_batch = size(Δ, 3), size(Δ, 4)

kern1 = get_downsamplekernel(Δ, k[1])
kern2 = get_downsamplekernel(Δ, k[2])
kern = kern1 * kern2'

pad = (floor(Int, k[1]//2), floor(Int, k[2]//2))
stride = k

weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan))
weight .= 0
for i in 1:n_chan
weight[:,:,i,i] .= kern
end
# weight = cat(fill(kern, n_chan)..., dims=(3,4)) # slow
dx = conv(Δ, weight, pad=pad, stride=stride)

# Still have to fix edge effects due to zero-padding of convolution,
# TODO: Could be circumvented by having padding that just extrapolates the value at the first/last index
# nextras = tuple((Int.(floor(factor//2)) for factor in k)...)
nextras = (floor(Int, k[1]//2), floor(Int, k[2]//2))

# First dimension edge-effect correction
if nextras[1] > 0
kern1 = kern[1:nextras[1],:]
pad1 = (0, pad[2])
stride1 = (1, stride[2])
weight1 = similar(Δ, eltype(Δ), (size(kern1)..., n_chan, n_chan))
weight1 .= 0
for i in 1:n_chan
weight1[:,:,i,i] .= kern1
end
# weight1 = cat(fill(kern1, n_chan)..., dims=(3,4)) # slow
dx[[1],:,:,:] .+= conv(Δ[1:nextras[1],:,:,:], weight1, pad=pad1, stride=stride1)
weight1 .= weight1[end:-1:1,:,:,:]
dx[[end],:,:,:] .+= conv(Δ[end-nextras[1]+1:end,:,:,:], weight1, pad=pad1, stride=stride1)

## Conv with views is not dispatched to CUDA.conv
# dx[[1],:,:,:] .+= conv(@view(Δ[1:nextras[1],:,:,:]), weight1, pad=pad1, stride=stride1)
# weight1 .= @view(weight1[end:-1:1,:,:,:])
# dx[[end],:,:,:] .+= conv(@view(Δ[end-nextras[1]+1:end,:,:,:]), weight1, pad=pad1, stride=stride1)
end
∇upsample_bilinear!(dx::AbstractArray{<:Any,4}, Δ::AbstractArray{<:Any,4}) = ∇upsample_bilinear_whcn_kernel!(dx, Δ)

# Second dimension edge-effect correction
if nextras[2] > 0
kern2 = kern[:,1:nextras[2]]
pad2 = (pad[1], 0)
stride2 = (stride[1], 1)
weight2 = similar(Δ, eltype(Δ), (size(kern2)..., n_chan, n_chan))
weight2 .= 0
for i in 1:n_chan
weight2[:,:,i,i] .= kern2
end
# weight2 = cat(fill(kern2, n_chan)..., dims=(3,4)) # slow

yy = conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2)
dx[:,[1],:,:] .+= conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2)
weight2 .= weight2[:,end:-1:1,:,:]
dx[:,[end],:,:] .+= conv(Δ[:,end-nextras[2]+1:end,:,:], weight2, pad=pad2, stride=stride2)

## Conv with views is not dispatched to CUDA.conv
# yy = conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2)
# dx[:,[1],:,:] .+= conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2)
# weight2 .= @view(weight2[:,end:-1:1,:,:])
# dx[:,[end],:,:] .+= conv(@view(Δ[:,end-nextras[2]+1:end,:,:]), weight2, pad=pad2, stride=stride2)
function ∇upsample_bilinear_whcn_kernel!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T
if size(dx) == size(Δ)
return Δ
end

## Finally fix four corners if needed
n1, n2 = nextras
if (n1 > 0) & (n2 > 0)
dx[1,1,:,:] .+= sum(kern[1:n1,1:n2] .* @view(Δ[1:n1,1:n2,:,:]), dims=(1,2))[1,1,:,:]
dx[1,end,:,:] .+= sum(kern[1:n1,end-n2+1:end] .* @view(Δ[1:n1,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:]
dx[end,end,:,:] .+= sum(kern[end-n1+1:end,end-n2+1:end] .* @view(Δ[end-n1+1:end,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:]
dx[end,1,:,:] .+= sum(kern[end-n1+1:end,1:n2] .* @view(Δ[end-n1+1:end,1:n2,:,:]), dims=(1,2))[1,1,:,:]
size(dx)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. Got dx $(size(dx)) and Δ $(size(Δ))")
in_w, in_h, channels, batches = size(dx)

# treat batch and channel dimension as one for better parallelization granularity
channels *= batches
out_w, out_h, _, _ = size(Δ)
output_slice_size = out_h * out_w

width_scale = T((in_w - 1) // (out_w - 1))
height_scale = T((in_h - 1) // (out_h - 1))

@inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1

@inbounds Threads.@threads for c in 0:channels-1
for oh in 0:out_h-1
ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h)
for ow in 0:out_w-1
iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
output_offset = c * output_slice_size + oh * out_w + ow + 1
Δ_value = Δ[output_offset]
dx[idx(c, ih0, iw0)] += h0lambda * w0lambda * Δ_value # i00
dx[idx(c, ih0, iw1)] += h0lambda * w1lambda * Δ_value # i01
dx[idx(c, ih1, iw0)] += h1lambda * w0lambda * Δ_value # i10
dx[idx(c, ih1, iw1)] += h1lambda * w1lambda * Δ_value # i11
end
end
end

return dx
end
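# Illustrative sketch (not part of the diff): the adjoint maps a gradient of the
# upsampled output back onto the original (W,H) given via the `size` keyword; for each
# output pixel the four interpolation weights sum to one, so the total gradient mass
# is preserved.
let Δ = ones(Float32, 6, 6, 2, 1)
    dx = ∇upsample_bilinear(Δ; size=(3, 3))
    @assert Base.size(dx) == (3, 3, 2, 1)
    @assert sum(dx) ≈ sum(Δ)
end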

# `n` upsample factor for which a downsample kernel will be determined.
# Δ is given in case of necessity of gpu conversion
function get_downsamplekernel(Δ, n::Int)
step = 1//n
if n % 2 == 0
start = step//2
upward = collect(start:step:1//1)
kernel = [upward; reverse(upward)]
else
start = step
upward = collect(start:step:1//1)
kernel = [upward; reverse(upward[1:end-1])]
end
# TODO there must be a more convenient way to send to gpu
kernel = convert(typeof(Δ), reshape(kernel, length(kernel), 1, 1, 1))
kernel = dropdims(kernel, dims=(2,3,4))
return kernel
end
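# Illustrative sketch (not part of the diff): the conv-based adjoint removed in this PR
# builds its downsampling kernel from the upsample factor; for a factor of 2 it yields
# the tent filter [0.25, 0.75, 0.75, 0.25], i.e. the weights bilinear upsampling by 2
# assigns to each source pixel.
let Δ = zeros(Float32, 1, 1, 1, 1)   # only used to pick the array type
    kern = get_downsamplekernel(Δ, 2)
    @assert kern == Float32[0.25, 0.75, 0.75, 0.25]
end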

function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, k)
Ω = upsample_bilinear(x, k)
function upsample_bilinear_pullback(Δ)
(NO_FIELDS, ∇upsample_bilinear(Δ, k), DoesNotExist())
end
return Ω, upsample_bilinear_pullback
end

function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, scale=(1,1); size=nothing)
Ω = upsample_bilinear(x, scale; size=size)
function upsample_bilinear_pullback(Δ)
(NO_FIELDS, ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2))), DoesNotExist())
end
return Ω, upsample_bilinear_pullback
end
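# Illustrative sketch (not part of the diff): pulling a gradient back through the rrule
# above; following ChainRulesCore's (Ω, pullback) convention, the scale argument gets a
# `DoesNotExist()` tangent and `dx` recovers the shape of `x`.
let x = rand(Float32, 4, 5, 1, 1)
    Ω, pullback = ChainRulesCore.rrule(upsample_bilinear, x, (2, 2))
    _, dx, _ = pullback(ones(Float32, Base.size(Ω)))
    @assert Base.size(Ω) == (8, 10, 1, 1)
    @assert Base.size(dx) == Base.size(x)
end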


"""
pixel_shuffle(x, r)

Pixel shuffling operation. `r` is the upscale factor for shuffling.
The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N]
Used extensively in super-resolution networks to upsample
towards high resolution features.

Reference : https://arxiv.org/pdf/1609.05158.pdf
@@ -301,7 +240,7 @@ function pixel_shuffle(x::AbstractArray, r::Integer)
@assert ndims(x) > 2
d = ndims(x) - 2
sizein = size(x)[1:d]
cin, n = size(x, d+1), size(x, d+2)
@assert cin % r^d == 0
cout = cin ÷ r^d
# x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866