
Commit ce92c3a

move code to src/nnlib.jl

1 parent 41a0f85 commit ce92c3a

File tree

6 files changed: +154, -154 lines changed


lib/cudnn/CUDNN.jl

Lines changed: 0 additions & 3 deletions
@@ -38,9 +38,6 @@ include("rnn.jl")
 include("multiheadattn.jl")
 include("normalization.jl")

-# custom kernels
-include("upsampling.jl")
-
 # high-level integrations
 include("nnlib.jl")
 include("batchnorm.jl")

lib/cudnn/nnlib.jl

Lines changed: 1 addition & 36 deletions
@@ -4,8 +4,7 @@ import NNlib: stride, padding, dilation, flipkernel, spatial_dims, kernel_size,
               conv!, ∇conv_filter!, ∇conv_data!,
               maxpool!, meanpool!, ∇maxpool!, ∇meanpool!, PoolDims,
               softmax, softmax!, ∇softmax, ∇softmax!,
-              logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!,
-              upsample_bilinear_whcn!, ∇upsample_bilinear_whcn!
+              logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!

 import DataStructures: DefaultDict

@@ -300,40 +299,6 @@ end
 Base.broadcasted(::typeof(identity), x::DenseCuArray{T}) where {T<:CUDNNFloat} = x


-# Upsampling
-
-function upsample_bilinear_whcn!(y::CuArray{T,4}, x::CuArray{T,4}) where T
-    w,h,c,n = size(x)
-    out_w, out_h = (size(y,1), size(y,2))
-
-    out_size = out_h*out_w
-    rheight = T((h-1)/(out_h-1))
-    rwidth = T((w-1)/(out_w-1))
-
-    kernel = @cuda name="upsample_bilinear_whcn!" launch=false upsample_bilinear_whcn_kernel!(out_size, rheight, rwidth, x, y)
-    config = launch_configuration(kernel.fun; max_threads=256)
-    threads = Base.min(out_size, config.threads)
-    blocks = cld(out_size, threads)
-    kernel(out_size, rheight, rwidth, x, y; threads=threads, blocks=blocks)
-    return y
-end
-
-function ∇upsample_bilinear_whcn!(dx::CuArray{T,4}, Δ::CuArray{T,4}) where T
-    w,h,c,n = Base.size(Δ)
-    out_w, out_h = (size(dx, 1), size(dx, 2))
-    in_size = h*w
-    rheight = T((out_h-1)/(h-1)) # reversed compared to forward pass
-    rwidth = T((out_w-1)/(w-1))
-
-    kernel = @cuda name="∇upsample_bilinear_whcn!" launch=false ∇upsample_bilinear_whcn_kernel!(in_size, rheight, rwidth, Δ, dx)
-    config = launch_configuration(kernel.fun; max_threads=256)
-    threads = Base.min(in_size, config.threads)
-    blocks = cld(in_size, threads)
-    kernel(in_size, rheight, rwidth, Δ, dx; threads=threads, blocks=blocks)
-    return dx
-end
-
-
 # Compatibility shims until users upgrade to new NNlib format
 function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}; pad=0, stride=1, flipkernel=0, dilation=1, kwargs...) where {T<:CUDNNFloat}
     cdims = DenseConvDims(x, w; padding=pad, stride=stride, flipkernel=(flipkernel!=0), dilation=dilation)
lib/cudnn/upsampling.jl

Lines changed: 0 additions & 88 deletions
This file was deleted.

src/nnlib.jl

Lines changed: 126 additions & 0 deletions
@@ -29,3 +29,129 @@ NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number,

 Base.unsafe_convert(::Type{CuPtr{T}}, A::NNlib.BatchedAdjOrTrans{T}) where {T} =
     Base.unsafe_convert(CuPtr{T}, parent(A))
+
+
+#
+# Upsampling
+#
+
+# An implementation for GPU based bilinear upsampling including its gradient
+# The code is a translation from the following files:
+# https://github.com/pytorch/pytorch/blob/master/caffe2/operators/upsample_op.cu
+# https://github.com/pytorch/pytorch/blob/master/caffe2/core/common_gpu.h
+
+# Forward and backward pass have been tested to produce the same output
+# as pytorch with align_corners=True - it works modulo bit noise.
+
+function upsample_bilinear_whcn_kernel!(n_elem, rheight, rwidth, x, y)
+    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
+
+    if index < n_elem
+        in_w, in_h, channels, batchsize = size(x)
+        out_w, out_h, _, _ = size(y)
+
+        ow = index % out_w
+        oh = index ÷ out_w
+
+        real_index = rheight*oh
+        ih0 = floor(Int, real_index)
+        offset = (ih0 < in_h-1) ? 1 : 0
+        ih1 = ih0 + offset + 1
+        h1lambda = real_index - ih0
+        h0lambda = 1 - h1lambda
+        ih0 += 1
+
+        real_index = rwidth*ow
+        iw0 = floor(Int, real_index)
+        offset = (iw0 < in_w-1) ? 1 : 0
+        iw1 = iw0 + offset + 1
+        w1lambda = real_index - iw0
+        w0lambda = 1 - w1lambda
+        iw0 += 1
+
+        @inbounds for n in 1:batchsize
+            for c in 1:channels
+                val = h0lambda * (w0lambda * x[iw0, ih0, c, n]  + # h0 * w0 * i00
+                                  w1lambda * x[iw1, ih0, c, n]) + # h0 * w1 * i01
+                      h1lambda * (w0lambda * x[iw0, ih1, c, n]  + # h1 * w0 * i10
+                                  w1lambda * x[iw1, ih1, c, n])   # h1 * w1 * i11
+                y[ow+1, oh+1, c, n] = val
+            end
+        end
+    end
+    return nothing
+end
+
+# Δ is the gradient backpropagated from downstream layers
+function ∇upsample_bilinear_whcn_kernel!(n_elem, rheight, rwidth, Δ, dx)
+    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x
+
+    if index < n_elem
+        in_width, in_height, channels, batchsize = size(Δ)
+        out_width, out_height, _, _ = size(dx)
+
+        iw = index % in_width
+        ih = index ÷ in_width
+
+        # Compute Y axis lambdas
+        real_index_h = rheight*ih
+        oh0 = floor(Int, real_index_h)
+        offset = (oh0 < out_height-1) ? 1 : 0
+        oh1 = oh0 + offset + 1
+        h1lambda = real_index_h - oh0
+        h0lambda = 1 - h1lambda
+        oh0 += 1
+
+        # Compute X axis lambdas
+        real_index_w = rwidth * iw
+        ow0 = floor(Int, real_index_w)
+        offset = (ow0 < out_width - 1) ? 1 : 0
+        ow1 = ow0 + offset + 1
+        w1lambda = real_index_w - ow0
+        w0lambda = 1 - w1lambda
+        ow0 += 1
+
+        @inbounds for n in 1:batchsize
+            for c in 1:channels
+                val = Δ[iw+1, ih+1, c, n]
+                @atomic dx[ow0, oh0, c, n] += h0lambda * w0lambda * val
+                @atomic dx[ow1, oh0, c, n] += h0lambda * w1lambda * val
+                @atomic dx[ow0, oh1, c, n] += h1lambda * w0lambda * val
+                @atomic dx[ow1, oh1, c, n] += h1lambda * w1lambda * val
+            end
+        end
+    end # if
+    return nothing
+end
+
+
+function NNlib.upsample_bilinear_whcn!(y::CuArray{T,4}, x::CuArray{T,4}) where T
+    w,h,c,n = size(x)
+    out_w, out_h = (size(y,1), size(y,2))
+
+    out_size = out_h*out_w
+    rheight = T((h-1)/(out_h-1))
+    rwidth = T((w-1)/(out_w-1))
+
+    kernel = @cuda name="upsample_bilinear_whcn!" launch=false upsample_bilinear_whcn_kernel!(out_size, rheight, rwidth, x, y)
+    config = launch_configuration(kernel.fun; max_threads=256)
+    threads = Base.min(out_size, config.threads)
+    blocks = cld(out_size, threads)
+    kernel(out_size, rheight, rwidth, x, y; threads=threads, blocks=blocks)
+    return y
+end
+
+function NNlib.∇upsample_bilinear_whcn!(dx::CuArray{T,4}, Δ::CuArray{T,4}) where T
+    w,h,c,n = Base.size(Δ)
+    out_w, out_h = (size(dx, 1), size(dx, 2))
+    in_size = h*w
+    rheight = T((out_h-1)/(h-1)) # reversed compared to forward pass
+    rwidth = T((out_w-1)/(w-1))
+
+    kernel = @cuda name="∇upsample_bilinear_whcn!" launch=false ∇upsample_bilinear_whcn_kernel!(in_size, rheight, rwidth, Δ, dx)
+    config = launch_configuration(kernel.fun; max_threads=256)
+    threads = Base.min(in_size, config.threads)
+    blocks = cld(in_size, threads)
+    kernel(in_size, rheight, rwidth, Δ, dx; threads=threads, blocks=blocks)
+    return dx
+end
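
For context (not part of the diff): a minimal usage sketch of how the new `NNlib.upsample_bilinear_whcn!` / `NNlib.∇upsample_bilinear_whcn!` methods are reached. NNlib's high-level `upsample_bilinear` (used unqualified in the tests below) presumably allocates the output and hands it to the `whcn!` method, which for `CuArray` inputs now resolves to the kernels above. With the align_corners=True convention, upsampling a 2×2 input to 6×4 uses rwidth = (2-1)/(6-1) = 0.2 and rheight = (2-1)/(4-1) ≈ 0.333.

    # Sketch only; assumes CUDA.jl and NNlib are loaded as in the test files.
    using CUDA, NNlib

    x = cu(rand(Float32, 2, 2, 3, 1))   # WHCN layout: 2×2 image, 3 channels, batch of 1
    y = upsample_bilinear(x, (3, 2))    # scale width ×3 and height ×2 → 6×4×3×1
    @assert size(y) == (6, 4, 3, 1)
    @assert y isa CuArray               # stayed on the GPU, i.e. hit the kernels above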

test/cudnn/nnlib.jl

Lines changed: 0 additions & 27 deletions
@@ -136,30 +136,3 @@ end
         CUDNN.batchnorm(v, v, m, v, v, 1.0; training=training)
     end
 end
-
-@testset "Bilinear upsampling" begin
-    x = Float32[1 2; 3 4][:,:,:,:]
-    x = cat(x,x; dims=3)
-    x = cat(x,x; dims=4)
-    x = cu(x)
-
-    y_true = Float32[ 1//1  4//3   5//3   2//1;
-                      7//5 26//15 31//15 12//5;
-                      9//5 32//15 37//15 14//5;
-                     11//5 38//15 43//15 16//5;
-                     13//5 44//15 49//15 18//5;
-                      3//1 10//3  11//3   4//1]
-    y_true = cat(y_true,y_true; dims=3)
-    y_true = cat(y_true,y_true; dims=4)
-    y_true = cu(y_true)
-
-    y = upsample_bilinear(x, (3,2))
-
-    @test size(y) == size(y_true)
-    @test eltype(y) == Float32
-    @test y ≈ y_true
-
-    o = CUDA.ones(Float32,6,4,2,1)
-    grad_true = 6*CUDA.ones(Float32,2,2,2,1)
-    @test ∇upsample_bilinear(o; size=(2,2)) ≈ grad_true
-end

test/nnlib.jl

Lines changed: 27 additions & 0 deletions
@@ -62,3 +62,30 @@ end
         @test testf(x -> logσ.(x), rand(5))
     end
 end
+
+@testset "Bilinear upsampling" begin
+    x = Float32[1 2; 3 4][:,:,:,:]
+    x = cat(x,x; dims=3)
+    x = cat(x,x; dims=4)
+    x = cu(x)
+
+    y_true = Float32[ 1//1  4//3   5//3   2//1;
+                      7//5 26//15 31//15 12//5;
+                      9//5 32//15 37//15 14//5;
+                     11//5 38//15 43//15 16//5;
+                     13//5 44//15 49//15 18//5;
+                      3//1 10//3  11//3   4//1]
+    y_true = cat(y_true,y_true; dims=3)
+    y_true = cat(y_true,y_true; dims=4)
+    y_true = cu(y_true)
+
+    y = upsample_bilinear(x, (3,2))
+
+    @test size(y) == size(y_true)
+    @test eltype(y) == Float32
+    @test y ≈ y_true
+
+    o = CUDA.ones(Float32,6,4,2,1)
+    grad_true = 6*CUDA.ones(Float32,2,2,2,1)
+    @test ∇upsample_bilinear(o; size=(2,2)) ≈ grad_true
+end
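
A quick sanity check (not part of the diff) of the expected gradient in the test above: every output pixel's four bilinear weights sum to 1, so an all-ones 6×4 upstream gradient distributes a total weight of 24 back onto the 2×2 input, and by the symmetry of this particular configuration each input pixel accumulates 24/4 = 6, matching `grad_true = 6*CUDA.ones(Float32,2,2,2,1)`.

    # Arithmetic behind the expected gradient value in the test above.
    out_elems = 6 * 4                     # all-ones upstream gradient, 6×4 per channel
    in_elems  = 2 * 2                     # flows back onto a 2×2 input
    @assert out_elems / in_elems == 6.0   # hence 6 in every entry of grad_true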

0 commit comments
