
Commit 3623541

Merge pull request #266 from maxfreu/master
improve bilinear upsampling
2 parents d9aaaf7 + 8e8e0dd commit 3623541

File tree

3 files changed: 224 additions & 260 deletions

src/upsample.jl

Lines changed: 126 additions & 195 deletions
@@ -64,235 +64,166 @@ function ChainRulesCore.rrule(::typeof(upsample_nearest), x::AbstractArray, s::T
     return Ω, upsample_nearest_pullback
 end
 
-"""
-    upsample_bilinear(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})
-
-Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `k`,
-using bilinear interpolation.
-
-The size of the output is equal to
-`(k[1]*S1, k[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.
-
-The interpolation grid is identical to the one used by `imresize` from `Images.jl`.
-
-Currently only 2d upsampling is supported.
-"""
-function upsample_bilinear(x::AbstractArray{T,4}, k::NTuple{2,Int}) where T
-    # This function is gpu friendly
-
-    imgsize = size(x)
-    newsize = get_newsize(imgsize, k)
-
-    # Get linear interpolation lower- and upper index, and weights
-    ilow1, ihigh1, wdiff1 = get_inds_and_ws(x, imgsize[1], newsize[1], 1)
-    ilow2, ihigh2, wdiff2 = get_inds_and_ws(x, imgsize[2], newsize[2], 2)
-
-    # Adjust the upper interpolation indices of the second dimension
-    ihigh2_r = adjoint_of_idx(ilow2)[ihigh2]
-
-    @inbounds y = @view(x[ilow1,ilow2,:,:]) .* (1 .- wdiff1) .+ @view(x[ihigh1,ilow2,:,:]) .* wdiff1
-    @inbounds y .= y .* (1 .- wdiff2) .+ y[:,ihigh2_r,:,:] .* wdiff2
-    # @inbounds y = y .* (1 .- wdiff2) .+ @view(y[:,ihigh2_r,:,:]) .* wdiff2 # equivalent to line above
-    return y
-end
-
-function get_inds_and_ws(x::T, n::Int, m::Int, dim::Int) where T <: AbstractArray
-    # Creates interpolation grid for resampling.
-    # Creates the same grid as used in Image.jl `imresize`.
-    step = n // m
-    offset = (n + 1)//2 - step//2 - step * (m//2 - 1)
-    xq = clamp.(range(offset, step=step, length=m), 1, n)
-
-    # Creates interpolation lower and upper indices, and broadcastable weights
-    ilow = floor.(Int, xq)
-    ihigh = ceil.(Int, xq)
-    sizew = ntuple(i -> i == dim ? length(xq) : 1, ndims(x))
-    wdiff = convert(T, reshape(xq .- ilow, sizew)) # wdiff possibly lives on gpu
-    return ilow, ihigh, wdiff
+# utility function
+@inline function compute_source_index_and_lambda(
+    ratio, # 0 < ratio < 1
+    output_index,
+    input_size,
+    output_size
+)
+    real_input_index = ratio*output_index
+    input_index0 = floor(Int, real_input_index) # typecast to int was here in C++
+    offset = (input_index0 < input_size - 1) ? 1 : 0
+    input_index1 = input_index0 + offset
+    lambda1 = real_input_index - input_index0
+    lambda0 = 1 - lambda1
+    return input_index0, input_index1, lambda0, lambda1
 end
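To see what the helper returns, here is a small trace for a length-3 axis upsampled to length 5, using the align-corners ratio `(input_size - 1)/(output_size - 1)` that the kernels below compute (a sketch, not part of the commit; it assumes the function above is in scope):

```julia
ratio = (3 - 1) / (5 - 1)   # 0.5
for output_index in 0:4
    i0, i1, l0, l1 = compute_source_index_and_lambda(ratio, output_index, 3, 5)
    @show (output_index, i0, i1, l0, l1)
end
# (0, 0, 1, 1.0, 0.0)  output 0 lands exactly on input 0
# (1, 0, 1, 0.5, 0.5)  halfway between inputs 0 and 1
# (4, 2, 2, 1.0, 0.0)  the last output lands on the last input (corners aligned)
```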
 
 """
-    adjoint_of_idx(idx::Vector{<:Integer})
-
-# Arguments
-- `idx`: a vector of indices from which you want the adjoint.
+    upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real})
+    upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer})
 
-# Outputs
-- `idx_adjoint`: index that inverses the operation `x[idx]`.
+Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `scale`,
+using bilinear interpolation. As an alternative to using `scale`, the resulting image `size`
+can be directly specified with a keyword argument.
 
-# Explanation
-Determines the adjoint of the vector of indices `idx`, based on the following assumptions:
-* `idx[1] == 1`
-* `all(d in [0,1] for d in diff(idx))`
-The adjoint of `idx` can be seen as an inverse operation such that:
+The size of the output is equal to
+`(scale[1]*S1, scale[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.
 
+Examples:
 ```julia
-x = [1, 2, 3, 4, 5]
-idx = [1, 2, 2, 3, 4, 4, 5]
-idx_adjoint = adjoint_of_idx(idx)
-@assert x[idx][idx_adjoint] == x
+upsample_bilinear(x, (2, pi)) # real scaling factors are allowed
+upsample_bilinear(x; size=(64,64)) # specify output size
 ```
-The above holds as long as `idx` contains every index in `x`.
 """
-function adjoint_of_idx(idx::Vector{Int})
-    d = trues(length(idx))
-    d[2:end] .= diff(idx)
-    idx_adjoint = findall(d)
-    return idx_adjoint
-end
-
-function get_newsize(sz, k)
-    return ntuple(i -> i <= length(k) ? sz[i]*k[i] : sz[i], length(sz))
+function upsample_bilinear(x::AbstractArray{<:Any,4}, scale::NTuple{2,Real})
+    outsize = ntuple(i -> floor(Int, scale[i] * Base.size(x, i)), 2)
+    return upsample_bilinear(x; size=outsize)
 end
 
+upsample_bilinear(x, scale::Real) = upsample_bilinear(x, (scale,scale))
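The tuple method floors `scale[i] * Base.size(x, i)` to an integer, so non-integer factors like `pi` are legal; a hedged sketch with illustrative shapes:

```julia
x = rand(Float32, 6, 8, 1, 1)          # WHCN layout
size(upsample_bilinear(x, 2))          # (12, 16, 1, 1), via the scalar convenience method
size(upsample_bilinear(x, (2, pi)))    # (12, 25, 1, 1), since floor(Int, pi * 8) == 25
```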
 
-"""
-    ∇upsample_bilinear(Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int})
-
-# Arguments
-- `Δ`: array that has been upsampled using the upsample factors in `k`
-
-# Outputs
-- `dx`: downsampled version of `Δ`
-
-# Explanation
-
-Custom adjoint for [`upsample_bilinear`](@ref).
-The adjoint of upsampling is a downsampling operation, which
-in this implementation is performed using `NNlib.conv` in combination with a downsampling kernel based on the
-upsampling factors. Because of the zero-padding during convolution, the values at the boundary are polluted by edge-effects,
-which have been corrected for manually.
-"""
-function ∇upsample_bilinear(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
-    # This function is gpu friendly
-
-    # Be more efficient on some corner cases
-    if size(Δ, 1) == k[1]
-        Δ = sum(Δ, dims=1)
-        k = (1, k[2])
-    end
-    if size(Δ, 2) == k[2]
-        Δ = sum(Δ, dims=2)
-        k = (k[1], 1)
+function upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}) where T
+    w,h,c,n = Base.size(x)
+    if (w,h) == size
+        return x
     end
-    if (size(Δ, 1) == 1) && (size(Δ, 2) == 1)
-        dx = Δ
-        return dx
-    end
-
-    n_chan, n_batch = size(Δ, 3), size(Δ, 4)
-
-    kern1 = get_downsamplekernel(Δ, k[1])
-    kern2 = get_downsamplekernel(Δ, k[2])
-    kern = kern1 * kern2'
-
-    pad = (floor(Int, k[1]//2), floor(Int, k[2]//2))
-    stride = k
-
-    weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan))
-    weight .= 0
-    for i in 1:n_chan
-        weight[:,:,i,i] .= kern
-    end
-    # weight = cat(fill(kern, n_chan)..., dims=(3,4)) # slow
-    dx = conv(Δ, weight, pad=pad, stride=stride)
+    y = similar(x, T, size..., c, n)
+    return upsample_bilinear_whcn!(y, x)
+end
 
-    # Still have to fix edge effects due to zero-padding of convolution,
-    # TODO: Could be circumvented by having padding that just extrapolates the value at the first/last index
-    # nextras = tuple((Int.(floor(factor//2)) for factor in k)...)
-    nextras = (floor(Int, k[1]//2), floor(Int, k[2]//2))
+function upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer}) where T<:Integer
+    y = float.(x)
+    res = upsample_bilinear(y; size=size)
+    return round.(T, res)
+end
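Since the core kernel interpolates in floating point, the `T<:Integer` method above converts, upsamples, and rounds back; a small sketch of the expected round trip:

```julia
x = reshape(UInt8[0 10; 20 30], 2, 2, 1, 1)
y = upsample_bilinear(x; size=(4, 4))   # computed as Float64, then round.(UInt8, ...)
eltype(y) == UInt8                      # true: the integer eltype is preserved
```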
 
-    # First dimension edge-effect correction
-    if nextras[1] > 0
-        kern1 = kern[1:nextras[1],:]
-        pad1 = (0, pad[2])
-        stride1 = (1, stride[2])
-        weight1 = similar(Δ, eltype(Δ), (size(kern1)..., n_chan, n_chan))
-        weight1 .= 0
-        for i in 1:n_chan
-            weight1[:,:,i,i] .= kern1
+# this is the core function which works on arrays of arbitrary size
+# the implementation is a translation of https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp
+# which implements open-cv style linear interpolation / upsampling
+# for simplicity, corners are aligned and all logic for other behaviour has been stripped
+# - whcn because there is also a cwhn implementation
+# - the function is parallelized using @threads
+# - RGB types could be supported via reinterpreting
+# - integer types need to be converted to Float and back
+# - rationals work, but are slow
+function upsample_bilinear_whcn!(output::AbstractArray{T,4}, input::AbstractArray{T,4}) where T
+    size(input)[3:4] == size(output)[3:4] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
+    in_w, in_h, channels, batches = size(input)
+    # treat batch and channel dimension as one for better parallelization granularity
+    channels *= batches
+    out_w, out_h, _, _ = size(output)
+    output_slice_size = out_h * out_w
+
+    # T() and // so that we can handle rationals (super slow)
+    width_scale = T((in_w - 1) // (out_w - 1))
+    height_scale = T((in_h - 1) // (out_h - 1))
+
+    @inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1
+
+    @inbounds Threads.@threads for c in 0:channels-1
+        for oh in 0:out_h-1
+            ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h)
+            for ow in 0:out_w-1
+                iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
+                output_offset = c * output_slice_size + oh * out_w + ow + 1
+                output[output_offset] =
+                    (h0lambda * w0lambda * input[idx(c, ih0, iw0)] + # h0 * w0 * i00
+                     h0lambda * w1lambda * input[idx(c, ih0, iw1)] + # h0 * w1 * i01
+                     h1lambda * w0lambda * input[idx(c, ih1, iw0)] + # h1 * w0 * i10
+                     h1lambda * w1lambda * input[idx(c, ih1, iw1)]) # h1 * w1 * i11
+            end
         end
-        # weight1 = cat(fill(kern1, n_chan)..., dims=(3,4)) # slow
-        dx[[1],:,:,:] .+= conv(Δ[1:nextras[1],:,:,:], weight1, pad=pad1, stride=stride1)
-        weight1 .= weight1[end:-1:1,:,:,:]
-        dx[[end],:,:,:] .+= conv(Δ[end-nextras[1]+1:end,:,:,:], weight1, pad=pad1, stride=stride1)
-
-        ## Conv with views is not dispatched to CUDA.conv
-        # dx[[1],:,:,:] .+= conv(@view(Δ[1:nextras[1],:,:,:]), weight1, pad=pad1, stride=stride1)
-        # weight1 .= @view(weight1[end:-1:1,:,:,:])
-        # dx[[end],:,:,:] .+= conv(@view(Δ[end-nextras[1]+1:end,:,:,:]), weight1, pad=pad1, stride=stride1)
     end
+    return output
+end
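Because `width_scale` and `height_scale` are `(in - 1)/(out - 1)`, corners are aligned and corner pixels pass through unchanged; a quick property check (illustrative sizes, assuming the kernel above):

```julia
x = rand(Float32, 2, 2, 1, 1)
y = upsample_bilinear(x; size=(5, 5))
y[1, 1, 1, 1] == x[1, 1, 1, 1]   # true: first output pixel equals first input pixel
y[5, 5, 1, 1] == x[2, 2, 1, 1]   # true: last output pixel equals last input pixel
```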
 
-    # Second dimension edge-effect correction
-    if nextras[2] > 0
-        kern2 = kern[:,1:nextras[2]]
-        pad2 = (pad[1], 0)
-        stride2 = (stride[1], 1)
-        weight2 = similar(Δ, eltype(Δ), (size(kern2)..., n_chan, n_chan))
-        weight2 .= 0
-        for i in 1:n_chan
-            weight2[:,:,i,i] .= kern2
-        end
-        # weight2 = cat(fill(kern2, n_chan)..., dims=(3,4)) # slow
-
-        yy = conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2)
-        dx[:,[1],:,:] .+= conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2)
-        weight2 .= weight2[:,end:-1:1,:,:]
-        dx[:,[end],:,:] .+= conv(Δ[:,end-nextras[2]+1:end,:,:], weight2, pad=pad2, stride=stride2)
+"""
+    ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) where T
 
-        ## Conv with views is not dispatched to CUDA.conv
-        # yy = conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2)
-        # dx[:,[1],:,:] .+= conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2)
-        # weight2 .= @view(weight2[:,end:-1:1,:,:])
-        # dx[:,[end],:,:] .+= conv(@view(Δ[:,end-nextras[2]+1:end,:,:]), weight2, pad=pad2, stride=stride2)
-    end
+# Arguments
+- `Δ`: Incoming gradient array, backpropagated from downstream layers
+- `size`: Lateral (W,H) size of the image upsampled in the first place
 
-    ## Finally fix four corners if needed
-    n1, n2 = nextras
-    if (n1 > 0) & (n2 > 0)
-        dx[1,1,:,:] .+= sum(kern[1:n1,1:n2] .* @view(Δ[1:n1,1:n2,:,:]), dims=(1,2))[1,1,:,:]
-        dx[1,end,:,:] .+= sum(kern[1:n1,end-n2+1:end] .* @view(Δ[1:n1,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:]
-        dx[end,end,:,:] .+= sum(kern[end-n1+1:end,end-n2+1:end] .* @view(Δ[end-n1+1:end,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:]
-        dx[end,1,:,:] .+= sum(kern[end-n1+1:end,1:n2] .* @view(Δ[end-n1+1:end,1:n2,:,:]), dims=(1,2))[1,1,:,:]
+# Outputs
+- `dx`: Downsampled version of `Δ`
+"""
+function ∇upsample_bilinear(Δ::AbstractArray{T,4}; size::NTuple{2,Integer}) where T
+    w, h, c, n = Base.size(Δ)
+    out_w, out_h = size
+    if (w,h) == (out_w, out_h)
+        return Δ
     end
-
-    return dx
+    dx = zero(similar(Δ, T, out_w, out_h, c, n))
+    return ∇upsample_bilinear_whcn!(dx, Δ)
 end
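`∇upsample_bilinear` is the linear adjoint of the forward pass: the backward kernel scatters with exactly the weights the forward kernel gathers with. The usual dot-product identity can be checked directly (a sketch; `LinearAlgebra` is assumed available):

```julia
using LinearAlgebra
x  = rand(Float32, 4, 4, 1, 1)
y  = upsample_bilinear(x; size=(8, 8))
Δ  = rand(Float32, size(y)...)
dx = ∇upsample_bilinear(Δ; size=(4, 4))
isapprox(dot(y, Δ), dot(x, dx))   # true up to Float32 rounding
```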
 
-# `n` upsample factor for which a downsample kernel will be determined.
-# Δ is given in case of necessity of gpu conversion
-function get_downsamplekernel(Δ, n::Int)
-    step = 1//n
-    if n % 2 == 0
-        start = step//2
-        upward = collect(start:step:1//1)
-        kernel = [upward; reverse(upward)]
-    else
-        start = step
-        upward = collect(start:step:1//1)
-        kernel = [upward; reverse(upward[1:end-1])]
+function ∇upsample_bilinear_whcn!(dx::AbstractArray{T,4}, Δ::AbstractArray{T,4}) where T
+    size(dx)[3:4] == size(Δ)[3:4] || error("Number of input and output channels and batches must match. Got dx $(size(dx)) and Δ $(size(Δ))")
+    in_w, in_h, channels, batches = size(dx)
+
+    # treat batch and channel dimension as one for better parallelization granularity
+    channels *= batches
+    out_w, out_h, _, _ = size(Δ)
+    output_slice_size = out_h * out_w
+
+    width_scale = T((in_w - 1) // (out_w - 1))
+    height_scale = T((in_h - 1) // (out_h - 1))
+
+    @inline idx(c, h, w) = c * in_h * in_w + h * in_w + w + 1
+
+    @inbounds Threads.@threads for c in 0:channels-1
+        for oh in 0:out_h-1
+            ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h)
+            for ow in 0:out_w-1
+                iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
+                output_offset = c * output_slice_size + oh * out_w + ow + 1
+                Δ_value = Δ[output_offset]
+                dx[idx(c, ih0, iw0)] += h0lambda * w0lambda * Δ_value # i00
+                dx[idx(c, ih0, iw1)] += h0lambda * w1lambda * Δ_value # i01
+                dx[idx(c, ih1, iw0)] += h1lambda * w0lambda * Δ_value # i10
+                dx[idx(c, ih1, iw1)] += h1lambda * w1lambda * Δ_value # i11
+            end
+        end
     end
-    # TODO there must be a more convenient way to send to gpu
-    kernel = convert(typeof(Δ), reshape(kernel, length(kernel), 1, 1, 1))
-    kernel = dropdims(kernel, dims=(2,3,4))
-    return kernel
+    return dx
 end
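For every output pixel the four scattered weights sum to one, since `(h0lambda + h1lambda) * (w0lambda + w1lambda) == 1`, so the backward pass conserves total gradient mass; a sanity sketch:

```julia
Δ  = rand(Float32, 8, 8, 1, 1)
dx = ∇upsample_bilinear(Δ; size=(4, 4))
isapprox(sum(dx), sum(Δ))   # true: each Δ entry is fully redistributed into dx
```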
 
-function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, k)
-    Ω = upsample_bilinear(x, k)
+function ChainRulesCore.rrule(::typeof(upsample_bilinear), x; size)
+    Ω = upsample_bilinear(x; size=size)
     function upsample_bilinear_pullback(Δ)
-        (NO_FIELDS, ∇upsample_bilinear(Δ, k), DoesNotExist())
+        (NO_FIELDS, ∇upsample_bilinear(Δ; size=(Base.size(x,1),Base.size(x,2))))
     end
     return Ω, upsample_bilinear_pullback
 end
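Calling the new rrule directly shows the wiring: the pullback reads the original lateral size off `x` and returns a two-element tuple, `NO_FIELDS` for the function and `dx` for the array (a sketch with illustrative sizes):

```julia
x = rand(Float32, 3, 3, 1, 1)
Ω, pullback = ChainRulesCore.rrule(upsample_bilinear, x; size=(6, 6))
_, dx = pullback(one.(Ω))   # seed the pullback with ones
size(dx) == size(x)         # true: the gradient matches the input shape
```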
 
-
 """
     pixel_shuffle(x, r)
-
+
 Pixel shuffling operation. `r` is the upscale factor for shuffling.
 The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N]
-Used extensively in super-resolution networks to upsample
+Used extensively in super-resolution networks to upsample
 towards high resolution features.
 
 Reference : https://arxiv.org/pdf/1609.05158.pdf
@@ -301,7 +232,7 @@ function pixel_shuffle(x::AbstractArray, r::Integer)
     @assert ndims(x) > 2
     d = ndims(x) - 2
     sizein = size(x)[1:d]
-    cin, n = size(x, d+1), size(x, d+2)
+    cin, n = size(x, d+1), size(x, d+2)
     @assert cin % r^d == 0
     cout = cin ÷ r^d
     # x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866
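A shape sketch for `pixel_shuffle` as documented above (illustrative sizes):

```julia
x = rand(Float32, 3, 3, 4, 1)   # W = 3, H = 3, r²C = 4 with r = 2, C = 1
y = pixel_shuffle(x, 2)
size(y) == (6, 6, 1, 1)         # true: [W,H,r²C,N] -> [rW,rH,C,N]
```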
