Skip to content

Commit 1ac7d2c

Browse files
cleanup
1 parent 38108ba commit 1ac7d2c

File tree

5 files changed

+100
-107
lines changed

5 files changed

+100
-107
lines changed

src/NNlib.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ end
2525

2626
include("activations.jl")
2727
include("softmax.jl")
28-
include("misc.jl")
2928
include("batched/batchedmul.jl")
3029
include("gemm.jl")
3130
include("conv.jl")

src/misc.jl

Lines changed: 0 additions & 25 deletions
This file was deleted.

src/upsample.jl

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export bilinear_upsample, ∇bilinear_upsample
1+
export bilinear_upsample, ∇bilinear_upsample, pixel_shuffle
22

33
"""
44
bilinear_upsample(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})
@@ -118,13 +118,10 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
118118

119119
n_chan, n_batch = size(Δ, 3), size(Δ, 4)
120120

121-
kern1 = get_downsamplekernel(k[1])
122-
kern2 = get_downsamplekernel(k[2])
121+
kern1 = get_downsamplekernel(Δ, k[1])
122+
kern2 = get_downsamplekernel(Δ, k[2])
123123
kern = kern1 .* kern2'
124-
# TODO there must be a more convenient way to send to gpu
125-
kern = convert(typeof(Δ), reshape(kern, size(kern)..., 1, 1))
126-
kern = dropdims(kern, dims=(3,4))
127-
124+
128125
pad = (floor(Int, k[1]//2), floor(Int, k[2]//2))
129126
stride = k
130127
weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan))
@@ -197,17 +194,9 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
197194
return dx
198195
end
199196

200-
201-
"""
202-
get_downsamplekernel(n::Int)
203-
204-
# Arguments
205-
- `n`: upsample factor for which a downsample kernel will be determined
206-
207-
# Outputs
208-
- `kernel`: downsample kernel
209-
"""
210-
function get_downsamplekernel(n::Int)
197+
# `n` upsample factor for which a downsample kernel will be determined.
198+
# Δ is given in case of necessity of gpu conversion
199+
function get_downsamplekernel(Δ, n::Int)
211200
step = 1//n
212201
if n % 2 == 0
213202
start = step//2
@@ -218,6 +207,9 @@ function get_downsamplekernel(n::Int)
218207
upward = collect(start:step:1//1)
219208
kernel = [upward; reverse(upward[1:end-1])]
220209
end
210+
# TODO there must be a more convenient way to send to gpu
211+
kernel = convert(typeof(Δ), reshape(kernel, length(kernel), 1, 1, 1))
212+
kernel = dropdims(kernel, dims=(2,3,4))
221213
return kernel
222214
end
223215

@@ -228,3 +220,28 @@ function ChainRulesCore.rrule(::typeof(bilinear_upsample), x, k)
228220
end
229221
return Ω, bilinear_upsample_pullback
230222
end
223+
224+
225+
"""
226+
pixel_shuffle(x, r)
227+
228+
Pixel shuffling operation. `r` is the upscale factor for shuffling.
229+
The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N]
230+
Used extensively in super-resolution networks to upsample
231+
towards high resolution features.
232+
233+
Reference : https://arxiv.org/pdf/1609.05158.pdf
234+
"""
235+
function pixel_shuffle(x::AbstractArray, r::Integer)
236+
@assert ndims(x) > 2
237+
d = ndims(x) - 2
238+
sizein = size(x)[1:d]
239+
cin, n = size(x, d+1), size(x, d+2)
240+
@assert cin % r^d == 0
241+
cout = cin ÷ r^d
242+
# x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866
243+
x = reshape(x, sizein..., ntuple(i->r, d)..., cout, n)
244+
perm = [d+1:2d 1:d]' |> vec # = [d+1, 1, d+2, 2, ..., 2d, d]
245+
x = permutedims(x, (perm..., 2d+1, 2d+2))
246+
return reshape(x, ((r .* sizein)..., cout, n))
247+
end

test/misc.jl

Lines changed: 0 additions & 63 deletions
This file was deleted.

test/upsample.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,68 @@
2626
@test Array(g_cpu) g_cpu atol=1e-4
2727
end
2828
end
29+
30+
@testset "pixel_shuffle" begin
31+
x = reshape(1:16, (2, 2, 4, 1))
32+
# [:, :, 1, 1] =
33+
# 1 3
34+
# 2 4
35+
# [:, :, 2, 1] =
36+
# 5 7
37+
# 6 8
38+
# [:, :, 3, 1] =
39+
# 9 11
40+
# 10 12
41+
# [:, :, 4, 1] =
42+
# 13 15
43+
# 14 16
44+
45+
y_true = [1 9 3 11
46+
5 13 7 15
47+
2 10 4 12
48+
6 14 8 16][:,:,:,:]
49+
50+
y = pixel_shuffle(x, 2)
51+
@test size(y) == size(y_true)
52+
@test y_true == y
53+
54+
x = reshape(1:32, (2, 2, 8, 1))
55+
y_true = zeros(Int, 4, 4, 2, 1)
56+
y_true[:,:,1,1] .= [ 1 9 3 11
57+
5 13 7 15
58+
2 10 4 12
59+
6 14 8 16 ]
60+
61+
y_true[:,:,2,1] .= [ 17 25 19 27
62+
21 29 23 31
63+
18 26 20 28
64+
22 30 24 32]
65+
66+
y = pixel_shuffle(x, 2)
67+
@test size(y) == size(y_true)
68+
@test y_true == y
69+
70+
x = reshape(1:4*3*27*2, (4,3,27,2))
71+
y = pixel_shuffle(x, 3)
72+
@test size(y) == (12, 9, 3, 2)
73+
# batch dimension is preserved
74+
x1 = x[:,:,:,[1]]
75+
x2 = x[:,:,:,[2]]
76+
y1 = pixel_shuffle(x1, 3)
77+
y2 = pixel_shuffle(x2, 3)
78+
@test cat(y1, y2, dims=4) == y
79+
80+
for d in [1, 2, 3]
81+
r = rand(1:5)
82+
n = rand(1:5)
83+
c = rand(1:5)
84+
insize = rand(1:5, d)
85+
x = rand(insize..., r^d*c, n)
86+
87+
y = pixel_shuffle(x, r)
88+
@test size(y) == ((r .* insize)..., c, n)
89+
90+
gradtest(x -> pixel_shuffle(x, r), x)
91+
end
92+
end
93+

0 commit comments

Comments
 (0)