
Commit 51b9354

implement bilinear upsampling

1 parent 486141a commit 51b9354

4 files changed: +208 -1 lines changed

lib/cudnn/CUDNN.jl

Lines changed: 4 additions & 0 deletions
@@ -5,6 +5,7 @@ using ..APIUtils
 using ..CUDA
 using ..CUDA: CUstream, libraryPropertyType
 using ..CUDA: libcudnn, @retry_reclaim, isdebug
+using ..CUDA: atomic_add!

 using CEnum

@@ -29,6 +30,9 @@ include("batchnorm.jl")
 include("dropout.jl")
 include("rnn.jl")

+# custom kernels
+include("upsampling.jl")
+
 # high-level integrations
 include("nnlib.jl")


lib/cudnn/nnlib.jl

Lines changed: 38 additions & 1 deletion
@@ -4,7 +4,8 @@ import NNlib: stride, padding, dilation, flipkernel, spatial_dims, kernel_size,
     conv!, ∇conv_filter!, ∇conv_data!,
     maxpool!, meanpool!, ∇maxpool!, ∇meanpool!,
     softmax, softmax!, ∇softmax, ∇softmax!,
-    logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!
+    logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!,
+    upsample_bilinear, ∇upsample_bilinear

 import DataStructures: DefaultDict

@@ -271,3 +272,39 @@ end
 # CUDNN_ACTIVATION_IDENTITY does not work with cudnnActivationForward
 # FIXME: put this optimization in GPUArrays' `copyto!` (like Base.Broadcast's `copyto!`)
 Base.broadcasted(::typeof(identity), x::DenseCuArray{T}) where {T<:CUDNNFloat} = x
+
+
+# Upsampling
+
+function upsample_bilinear(x::CuArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T
+    w,h,c,n = size(x)
+    width_scale, height_scale = Float32.(scale)
+    if outsize === nothing
+        out_w = floor(Int, width_scale*w)
+        out_h = floor(Int, height_scale*h)
+    else
+        out_w, out_h = outsize
+    end
+    out_size = out_h*out_w
+    nblocks = GET_BLOCKS(out_size)
+    out = CuArray{T}(undef, out_w, out_h, c, n)
+    CUDA.@sync @cuda blocks=nblocks threads=CUDA_NUM_THREADS upsample_bilinear_kernel!(n, c, h, w, out_h, out_w, height_scale, width_scale, x, out)
+    return out
+end
+
+function ∇upsample_bilinear(Δ::CuArray{T,4}, scale::NTuple{2,Real}=(1,1); outsize::Union{Nothing,NTuple{2,Integer}}=nothing) where T
+    w,h,c,n = size(Δ)
+    input_size = length(Δ)
+    width_scale, height_scale = Float32.(scale)
+    if outsize === nothing
+        out_w = ceil(Int, w/width_scale)
+        out_h = ceil(Int, h/height_scale)
+    else
+        out_w, out_h = outsize
+    end
+    out_size = out_h * out_w
+    nblocks = GET_BLOCKS(out_size)
+    dx = zero(CuArray{T}(undef, out_w, out_h, c, n))
+    CUDA.@sync @cuda blocks=nblocks threads=CUDA_NUM_THREADS ∇upsample_bilinear_kernel(input_size, c, h, w, out_h, out_w, height_scale, width_scale, Δ, dx)
+    return dx
+end
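
For context, a minimal usage sketch of the two new entry points (not part of the diff; it assumes this commit is loaded and that `upsample_bilinear`/`∇upsample_bilinear` are reachable, e.g. via NNlib):

using CUDA

x  = CUDA.rand(Float32, 8, 8, 3, 2)     # WHCN layout, as used by Flux/NNlib
y  = upsample_bilinear(x, (2, 2))       # forward: 16×16×3×2
Δ  = CUDA.ones(Float32, size(y)...)     # upstream gradient w.r.t. y
dx = ∇upsample_bilinear(Δ, (2, 2))      # backward: 8×8×3×2, same size as x

@assert size(y)  == (16, 16, 3, 2)
@assert size(dx) == size(x)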

lib/cudnn/upsampling.jl

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+# A GPU implementation of bilinear upsampling and its gradient.
+# The code is a translation of the following files:
+# https://github.com/pytorch/pytorch/blob/master/caffe2/operators/upsample_op.cu
+# https://github.com/pytorch/pytorch/blob/master/caffe2/core/common_gpu.h
+
+# The forward and backward passes have been tested to produce the same output
+# as PyTorch, up to floating-point noise.
+
+const CUDA_NUM_THREADS = 128
+const MAXIMUM_NUM_BLOCKS = 4096
+
+@inline function GET_BLOCKS(N::Integer)
+    # Use at least 1 block, since CUDA does not allow launching an empty grid
+    return max(min((N + CUDA_NUM_THREADS - 1) ÷ CUDA_NUM_THREADS, MAXIMUM_NUM_BLOCKS), 1)
+end
+
+# pytorch: nchw with row major
+# flux: whcn with column major -> same data layout in memory -> this function
+# can stay as it is, except for the trailing +1 (Julia arrays are 1-indexed)
+@inline function idx(
+    n::Integer,
+    num_channels::Integer,
+    c::Integer,
+    height::Integer,
+    width::Integer,
+    y::Integer,
+    x::Integer)
+    return ((n * num_channels + c) * height + y) * width + x + 1
+end
+
+function upsample_bilinear_kernel!(
+    num_batch,
+    num_channels,
+    input_height,
+    input_width,
+    output_height,
+    output_width,
+    height_scale,
+    width_scale,
+    X,  # input, __restrict__ in the CUDA original
+    Y)  # output, __restrict__ in the CUDA original
+    out_size = output_height * output_width
+
+    # CUDA 1D grid-stride loop
+    @inbounds for index in ((blockIdx().x-1) * blockDim().x + threadIdx().x-1) : (blockDim().x * gridDim().x) : out_size-1
+        # mind the order!
+        indexTemp = index
+        out_x = indexTemp % output_width
+        indexTemp ÷= output_width
+        out_y = indexTemp % output_height
+        indexTemp ÷= output_height
+        indexTemp ÷= num_channels
+
+        rheight = output_height > 1 ? (input_height - 1f0) / (output_height - 1f0) : 0f0
+        rwidth  = output_width  > 1 ? (input_width  - 1f0) / (output_width  - 1f0) : 0f0
+
+        # Compute Y axis lambdas
+        h1r = rheight * out_y
+        h1 = floor(Int, h1r)  # the CUDA original uses a typecast to int here
+        h1p = (h1 < input_height - 1) ? 1 : 0
+        h1lambda = h1r - h1
+        h0lambda = 1f0 - h1lambda
+
+        # Compute X axis lambdas
+        w1r = rwidth * out_x
+        w1 = floor(Int, w1r)
+        w1p = (w1 < input_width - 1) ? 1 : 0
+        w1lambda = w1r - w1
+        w0lambda = 1f0 - w1lambda
+
+        for n in 0:num_batch-1  # shift to the original C indexing
+            for c in 0:num_channels-1
+                X0 = X[idx(n, num_channels, c, input_height, input_width, h1, w1)]
+                X1 = X[idx(n, num_channels, c, input_height, input_width, h1, w1 + w1p)]
+                X2 = X[idx(n, num_channels, c, input_height, input_width, h1 + h1p, w1)]
+                X3 = X[idx(n, num_channels, c, input_height, input_width, h1 + h1p, w1 + w1p)]
+
+                Y[idx(n, num_channels, c, output_height, output_width, out_y, out_x)] =
+                    h0lambda * (w0lambda * X0 + w1lambda * X1) +
+                    h1lambda * (w0lambda * X2 + w1lambda * X3)
+            end # channels
+        end # batch
+    end # 1D kernel loop
+    return nothing
+end
+
+# input is dY, output is dX
+function ∇upsample_bilinear_kernel(
+    input_size,
+    num_channels,
+    input_height,
+    input_width,
+    output_height,
+    output_width,
+    height_scale,
+    width_scale,
+    dY,  # const in the CUDA original
+    dX)
+    @inbounds for index in ((blockIdx().x - 1) * blockDim().x + threadIdx().x-1) : (blockDim().x * gridDim().x) : input_size-1
+        # mind the order!
+        indexTemp = index
+        in_x = indexTemp % input_width
+        indexTemp ÷= input_width
+        in_y = indexTemp % input_height
+        indexTemp ÷= input_height
+        c = indexTemp % num_channels
+        indexTemp ÷= num_channels
+        n = indexTemp
+
+        out_y = min(in_y / height_scale, output_height - 1)
+        out_x = min(in_x / width_scale, output_width - 1)
+
+        rheight = output_height > 1 ? (output_height - 1f0) / (input_height - 1f0) : 0f0
+        rwidth  = output_width  > 1 ? (output_width  - 1f0) / (input_width  - 1f0) : 0f0
+
+        # Compute Y axis lambdas
+        h1r = rheight * in_y
+        h1 = round(Int, h1r, RoundDown)
+        h1p = (h1 < output_height - 1) ? 1 : 0
+        h1lambda = h1r - h1
+        h0lambda = 1f0 - h1lambda
+
+        # Compute X axis lambdas
+        w1r = rwidth * in_x
+        w1 = round(Int, w1r, RoundDown)
+        w1p = (w1 < output_width - 1) ? 1 : 0
+        w1lambda = w1r - w1
+        w0lambda = 1f0 - w1lambda
+
+        #if __CUDA_ARCH__ >= 350  # true for everything from 9xx on
+        dYi = ldg(dY, index+1)  # ldg(pointer(dY[index])) ?
+        #else
+        #    dYi = dY[index + 1]
+        #endif
+
+        atomic_add!(pointer(dX, idx(n, num_channels, c, output_height, output_width, h1, w1)), h0lambda * w0lambda * dYi)
+        atomic_add!(pointer(dX, idx(n, num_channels, c, output_height, output_width, h1, w1 + w1p)), h0lambda * w1lambda * dYi)
+        atomic_add!(pointer(dX, idx(n, num_channels, c, output_height, output_width, h1 + h1p, w1)), h1lambda * w0lambda * dYi)
+        atomic_add!(pointer(dX, idx(n, num_channels, c, output_height, output_width, h1 + h1p, w1 + w1p)), h1lambda * w1lambda * dYi)
+    end
+    return nothing
+end
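
As a sanity check on the kernel above, here is a plain-Julia sketch of the same align-corners interpolation for a single W×H channel. It is not part of the commit; `cpu_upsample_bilinear` is an illustrative name, and the weight arithmetic simply mirrors the kernel's "lambdas":

function cpu_upsample_bilinear(x::AbstractMatrix, out_w::Integer, out_h::Integer)
    in_w, in_h = size(x)
    # same align-corners ratios as the kernel
    rwidth  = out_w > 1 ? (in_w - 1) / (out_w - 1) : 0.0
    rheight = out_h > 1 ? (in_h - 1) / (out_h - 1) : 0.0
    y = similar(x, float(eltype(x)), out_w, out_h)
    for oy in 0:out_h-1, ox in 0:out_w-1
        # source cell and interpolation weights
        w1r = rwidth * ox;  w1 = floor(Int, w1r);  w1p = w1 < in_w - 1 ? 1 : 0
        h1r = rheight * oy; h1 = floor(Int, h1r);  h1p = h1 < in_h - 1 ? 1 : 0
        w1lambda = w1r - w1; w0lambda = 1 - w1lambda
        h1lambda = h1r - h1; h0lambda = 1 - h1lambda
        y[ox+1, oy+1] =
            h0lambda * (w0lambda * x[w1+1, h1+1]     + w1lambda * x[w1+w1p+1, h1+1]) +
            h1lambda * (w0lambda * x[w1+1, h1+h1p+1] + w1lambda * x[w1+w1p+1, h1+h1p+1])
    end
    return y
end

On the 2×2 input used in test/cudnn.jl below, `cpu_upsample_bilinear(Float32[1 2; 3 4], 6, 4)` reproduces the `y_true` matrix of that test.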

test/cudnn.jl

Lines changed: 24 additions & 0 deletions
@@ -109,3 +109,27 @@ end
         CUDNN.batchnorm(v, v, m, v, v, 1.0; training=training)
     end
 end
+
+@testset "Bilinear upsampling" begin
+    # only the forward pass is tested, more testing can be done in NNlib
+    x = Float32[1 2; 3 4][:,:,:,:]
+    x = cat(x,x; dims=3)
+    x = cat(x,x; dims=4)
+    x = cu(x)
+
+    y_true = Float32[ 1//1  4//3   5//3   2//1;
+                      7//5  26//15 31//15 12//5;
+                      9//5  32//15 37//15 14//5;
+                      11//5 38//15 43//15 16//5;
+                      13//5 44//15 49//15 18//5;
+                      3//1  10//3  11//3  4//1]
+    y_true = cat(y_true,y_true; dims=3)
+    y_true = cat(y_true,y_true; dims=4)
+    y_true = cu(y_true)
+
+    y = upsample_bilinear(x, (3,2))
+
+    @test size(y) == size(y_true)
+    @test eltype(y) == Float32
+    @test y ≈ y_true
+end
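
The test above only covers the forward pass. A hedged sketch of one extra check that could be added for the backward pass (not part of the commit; it assumes `∇upsample_bilinear` is in scope like `upsample_bilinear` above):

Δ  = cu(ones(Float32, 6, 4, 2, 2))   # same shape as the forward output above
dx = ∇upsample_bilinear(Δ, (3, 2))

@test size(dx) == (2, 2, 2, 2)
# the four interpolation weights for each element sum to one, so scattering a
# gradient of ones back to the input should conserve its total mass
@test sum(dx) ≈ sum(Δ)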
