
Commit acb6916

Merge pull request #262 from FluxML/cl/upsample
add bilinear upsampling
2 parents 090a0ec + 147f03d commit acb6916

6 files changed: +285 -30 lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@ Requires = "0.5, 1.0"
 julia = "1.3"
 
 [extras]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -23,4 +24,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["ChainRulesTestUtils", "FiniteDifferences", "Random", "StableRNGs", "Test", "Zygote"]
+test = ["ChainRulesTestUtils", "CUDA", "FiniteDifferences", "Random", "StableRNGs", "Test", "Zygote"]

src/NNlib.jl

Lines changed: 1 addition & 2 deletions
@@ -24,14 +24,13 @@ is_nnpack_available() = false
 end
 
 include("activations.jl")
-
 include("softmax.jl")
-include("misc.jl")
 include("batched/batchedmul.jl")
 include("gemm.jl")
 include("conv.jl")
 include("conv_bias_act.jl")
 include("pooling.jl")
+include("upsample.jl")
 
 ## Include implementations
 include("impl/padding_edges.jl")

src/misc.jl

Lines changed: 0 additions & 25 deletions
This file was deleted.

src/upsample.jl

Lines changed: 248 additions & 0 deletions
@@ -0,0 +1,248 @@
export upsample_bilinear, ∇upsample_bilinear, pixel_shuffle

"""
    upsample_bilinear(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})

Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `k`,
using bilinear interpolation.

The size of the output is equal to
`(k[1]*S1, k[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.

The interpolation grid is identical to the one used by `imresize` from `Images.jl`.

Currently only 2d upsampling is supported.
"""
function upsample_bilinear(x::AbstractArray{T,4}, k::NTuple{2,Int}) where T
    # This function is gpu friendly

    imgsize = size(x)
    newsize = get_newsize(imgsize, k)

    # Get linear interpolation lower- and upper index, and weights
    ilow1, ihigh1, wdiff1 = get_inds_and_ws(x, imgsize[1], newsize[1], 1)
    ilow2, ihigh2, wdiff2 = get_inds_and_ws(x, imgsize[2], newsize[2], 2)

    # Adjust the upper interpolation indices of the second dimension
    ihigh2_r = adjoint_of_idx(ilow2)[ihigh2]

    @inbounds y = @view(x[ilow1,ilow2,:,:]) .* (1 .- wdiff1) .+ @view(x[ihigh1,ilow2,:,:]) .* wdiff1
    @inbounds y .= y .* (1 .- wdiff2) .+ y[:,ihigh2_r,:,:] .* wdiff2
    # @inbounds y = y .* (1 .- wdiff2) .+ @view(y[:,ihigh2_r,:,:]) .* wdiff2 # equivalent to line above
    return y
end

function get_inds_and_ws(x::T, n::Int, m::Int, dim::Int) where T <: AbstractArray
    # Creates interpolation grid for resampling.
    # Creates the same grid as used in Images.jl `imresize`.
    step = n // m
    offset = (n + 1)//2 - step//2 - step * (m//2 - 1)
    xq = clamp.(range(offset, step=step, length=m), 1, n)

    # Creates interpolation lower and upper indices, and broadcastable weights
    ilow = floor.(Int, xq)
    ihigh = ceil.(Int, xq)
    sizew = ntuple(i -> i == dim ? length(xq) : 1, ndims(x))
    wdiff = convert(T, reshape(xq .- ilow, sizew)) # wdiff possibly lives on gpu
    return ilow, ihigh, wdiff
end

"""
    adjoint_of_idx(idx::Vector{<:Integer})

# Arguments
- `idx`: a vector of indices from which you want the adjoint.

# Outputs
- `idx_adjoint`: the index vector that inverts the operation `x[idx]`.

# Explanation
Determines the adjoint of the vector of indices `idx`, based on the following assumptions:
* `idx[1] == 1`
* `all(d in [0,1] for d in diff(idx))`
The adjoint of `idx` can be seen as an inverse operation such that:

```julia
x = [1, 2, 3, 4, 5]
idx = [1, 2, 2, 3, 4, 4, 5]
idx_adjoint = adjoint_of_idx(idx)
@assert x[idx][idx_adjoint] == x
```
The above holds as long as `idx` contains every index in `x`.
"""
function adjoint_of_idx(idx::Vector{Int})
    d = trues(length(idx))
    d[2:end] .= diff(idx)
    idx_adjoint = findall(d)
    return idx_adjoint
end

function get_newsize(sz, k)
    return ntuple(i -> i <= length(k) ? sz[i]*k[i] : sz[i], length(sz))
end

"""
    ∇upsample_bilinear(Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int})

# Arguments
- `Δ`: array that has been upsampled using the upsample factors in `k`

# Outputs
- `dx`: downsampled version of `Δ`

# Explanation

Custom adjoint for [`upsample_bilinear`](@ref).
The adjoint of upsampling is a downsampling operation, which
in this implementation is performed using `NNlib.conv` in combination with a downsampling kernel
derived from the upsampling factors. Because of the zero-padding used by the convolution,
the values at the boundary suffer from edge effects, which are corrected manually.
"""
function ∇upsample_bilinear(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
    # This function is gpu friendly

    # Be more efficient on some corner cases
    if size(Δ, 1) == k[1]
        Δ = sum(Δ, dims=1)
        k = (1, k[2])
    end
    if size(Δ, 2) == k[2]
        Δ = sum(Δ, dims=2)
        k = (k[1], 1)
    end
    if (size(Δ, 1) == 1) && (size(Δ, 2) == 1)
        dx = Δ
        return dx
    end

    n_chan, n_batch = size(Δ, 3), size(Δ, 4)

    kern1 = get_downsamplekernel(Δ, k[1])
    kern2 = get_downsamplekernel(Δ, k[2])
    kern = kern1 * kern2'

    pad = (floor(Int, k[1]//2), floor(Int, k[2]//2))
    stride = k

    weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan))
    weight .= 0
    for i in 1:n_chan
        weight[:,:,i,i] .= kern
    end
    # weight = cat(fill(kern, n_chan)..., dims=(3,4)) # slow
    dx = conv(Δ, weight, pad=pad, stride=stride)

    # Still have to fix edge effects due to zero-padding of convolution,
    # TODO: Could be circumvented by having padding that just extrapolates the value at the first/last index
    # nextras = tuple((Int.(floor(factor//2)) for factor in k)...)
    nextras = (floor(Int, k[1]//2), floor(Int, k[2]//2))

    # First dimension edge-effect correction
    if nextras[1] > 0
        kern1 = kern[1:nextras[1],:]
        pad1 = (0, pad[2])
        stride1 = (1, stride[2])
        weight1 = similar(Δ, eltype(Δ), (size(kern1)..., n_chan, n_chan))
        weight1 .= 0
        for i in 1:n_chan
            weight1[:,:,i,i] .= kern1
        end
        # weight1 = cat(fill(kern1, n_chan)..., dims=(3,4)) # slow
        dx[[1],:,:,:] .+= conv(Δ[1:nextras[1],:,:,:], weight1, pad=pad1, stride=stride1)
        weight1 .= weight1[end:-1:1,:,:,:]
        dx[[end],:,:,:] .+= conv(Δ[end-nextras[1]+1:end,:,:,:], weight1, pad=pad1, stride=stride1)

        ## Conv with views is not dispatched to CUDA.conv
        # dx[[1],:,:,:] .+= conv(@view(Δ[1:nextras[1],:,:,:]), weight1, pad=pad1, stride=stride1)
        # weight1 .= @view(weight1[end:-1:1,:,:,:])
        # dx[[end],:,:,:] .+= conv(@view(Δ[end-nextras[1]+1:end,:,:,:]), weight1, pad=pad1, stride=stride1)
    end

    # Second dimension edge-effect correction
    if nextras[2] > 0
        kern2 = kern[:,1:nextras[2]]
        pad2 = (pad[1], 0)
        stride2 = (stride[1], 1)
        weight2 = similar(Δ, eltype(Δ), (size(kern2)..., n_chan, n_chan))
        weight2 .= 0
        for i in 1:n_chan
            weight2[:,:,i,i] .= kern2
        end
        # weight2 = cat(fill(kern2, n_chan)..., dims=(3,4)) # slow

        yy = conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2)
        dx[:,[1],:,:] .+= conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2)
        weight2 .= weight2[:,end:-1:1,:,:]
        dx[:,[end],:,:] .+= conv(Δ[:,end-nextras[2]+1:end,:,:], weight2, pad=pad2, stride=stride2)

        ## Conv with views is not dispatched to CUDA.conv
        # yy = conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2)
        # dx[:,[1],:,:] .+= conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2)
        # weight2 .= @view(weight2[:,end:-1:1,:,:])
        # dx[:,[end],:,:] .+= conv(@view(Δ[:,end-nextras[2]+1:end,:,:]), weight2, pad=pad2, stride=stride2)
    end

    ## Finally fix four corners if needed
    n1, n2 = nextras
    if (n1 > 0) & (n2 > 0)
        dx[1,1,:,:] .+= sum(kern[1:n1,1:n2] .* @view(Δ[1:n1,1:n2,:,:]), dims=(1,2))[1,1,:,:]
        dx[1,end,:,:] .+= sum(kern[1:n1,end-n2+1:end] .* @view(Δ[1:n1,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:]
        dx[end,end,:,:] .+= sum(kern[end-n1+1:end,end-n2+1:end] .* @view(Δ[end-n1+1:end,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:]
        dx[end,1,:,:] .+= sum(kern[end-n1+1:end,1:n2] .* @view(Δ[end-n1+1:end,1:n2,:,:]), dims=(1,2))[1,1,:,:]
    end

    return dx
end

# `n`: the upsample factor for which a downsample kernel will be determined
# `Δ` is passed in case a conversion to the gpu is necessary
function get_downsamplekernel(Δ, n::Int)
    step = 1//n
    if n % 2 == 0
        start = step//2
        upward = collect(start:step:1//1)
        kernel = [upward; reverse(upward)]
    else
        start = step
        upward = collect(start:step:1//1)
        kernel = [upward; reverse(upward[1:end-1])]
    end
    # TODO there must be a more convenient way to send to gpu
    kernel = convert(typeof(Δ), reshape(kernel, length(kernel), 1, 1, 1))
    kernel = dropdims(kernel, dims=(2,3,4))
    return kernel
end

function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, k)
    Ω = upsample_bilinear(x, k)
    function upsample_bilinear_pullback(Δ)
        (NO_FIELDS, ∇upsample_bilinear(Δ, k), DoesNotExist())
    end
    return Ω, upsample_bilinear_pullback
end


"""
    pixel_shuffle(x, r)

Pixel shuffling operation. `r` is the upscale factor for shuffling.
The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N].
Used extensively in super-resolution networks to upsample
towards high resolution features.

Reference: https://arxiv.org/pdf/1609.05158.pdf
"""
function pixel_shuffle(x::AbstractArray, r::Integer)
    @assert ndims(x) > 2
    d = ndims(x) - 2
    sizein = size(x)[1:d]
    cin, n = size(x, d+1), size(x, d+2)
    @assert cin % r^d == 0
    cout = cin ÷ r^d
    # x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866
    x = reshape(x, sizein..., ntuple(i->r, d)..., cout, n)
    perm = [d+1:2d 1:d]' |> vec  # = [d+1, 1, d+2, 2, ..., 2d, d]
    x = permutedims(x, (perm..., 2d+1, 2d+2))
    return reshape(x, ((r .* sizein)..., cout, n))
end
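
For orientation, a minimal usage sketch of the new public function (assuming NNlib at this commit is loaded; the input mirrors the test added further below):

```julia
using NNlib

x = reshape(Float32[1 2; 3 4], (2, 2, 1, 1))  # (W, H, C, N)
y = upsample_bilinear(x, (3, 2))              # upsample dim 1 by 3, dim 2 by 2

size(y)    # (6, 4, 1, 1), i.e. (3*2, 2*2, 1, 1), matching the docstring's size formula
eltype(y)  # Float32, the element type of x is preserved
```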

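The adjoint docstring above describes downsampling with a kernel derived from the upsample factors. A small illustration of the 1-d kernels built by `get_downsamplekernel` (an internal, unexported helper defined in this file, so it is qualified with `NNlib.` here; the dummy `Δ` array only serves to pick the output array type):

```julia
using NNlib

Δ = zeros(Float32, 1, 1, 1, 1)    # dummy array, used only to choose the array/element type

NNlib.get_downsamplekernel(Δ, 2)  # Float32[0.25, 0.75, 0.75, 0.25]
NNlib.get_downsamplekernel(Δ, 3)  # ≈ Float32[1/3, 2/3, 1, 2/3, 1/3]

# In ∇upsample_bilinear the 2-d kernel is the outer product kern1 * kern2',
# so for k = (2, 2) it is a 4×4 matrix.
```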
test/runtests.jl

Lines changed: 4 additions & 2 deletions
@@ -6,6 +6,8 @@ using FiniteDifferences: FiniteDifferenceMethod, central_fdm
 import Zygote
 using Zygote: gradient
 using StableRNGs
+using CUDA
+CUDA.allowscalar(false)
 
 const rng = StableRNG(123)
 
@@ -36,6 +38,6 @@ end
     include("softmax.jl")
 end
 
-@testset "Misc Stuff" begin
-    include("misc.jl")
+@testset "Upsampling" begin
+    include("upsample.jl")
 end

test/misc.jl renamed to test/upsample.jl

Lines changed: 30 additions & 0 deletions
@@ -1,3 +1,32 @@
+@testset "upsample_bilinear 2d" begin
+    x = reshape(Float32[1. 2.; 3. 4.], (2,2,1,1))
+    y_true = [1//1  5//4   7//4   2//1;
+              1//1  5//4   7//4   2//1;
+              5//3  23//12 29//12 8//3;
+              7//3  31//12 37//12 10//3;
+              3//1  13//4  15//4  4//1;
+              3//1  13//4  15//4  4//1][:,:,:,:]
+
+    y = upsample_bilinear(x, (3, 2))
+    @test size(y) == size(y_true)
+    @test eltype(y) == Float32
+    @test y ≈ y_true
+
+    gradtest(x -> upsample_bilinear(x, (3, 2)), x, atol=1e-4)
+
+    if CUDA.has_cuda()
+        y = upsample_bilinear(x |> cu, (3, 2))
+        @test y isa CuArray
+        @test Array(y) ≈ y_true
+        g_gpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2))))
+                                , x |> cu)[1]
+        @test g_gpu isa CuArray
+        g_cpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2))))
+                                , x)[1]
+        @test Array(g_gpu) ≈ g_cpu atol=1e-4
+    end
+end
+
 @testset "pixel_shuffle" begin
     x = reshape(1:16, (2, 2, 4, 1))
     # [:, :, 1, 1] =
@@ -61,3 +90,4 @@
         gradtest(x -> pixel_shuffle(x, r), x)
     end
 end
+

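A similar quick sketch for `pixel_shuffle`, using the same input as the pixel_shuffle test above (again assuming NNlib at this commit is loaded):

```julia
using NNlib

x = reshape(1:16, (2, 2, 4, 1))  # (W, H, r²*C, N) with r = 2, C = 1
y = pixel_shuffle(x, 2)

size(y)  # (4, 4, 1, 1): each group of r² = 4 channels becomes one 2×2 spatial block
```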