1
- export bilinear_upsample, ∇bilinear_upsample
1
+ export bilinear_upsample, ∇bilinear_upsample, pixel_shuffle
2
2
3
3
"""
4
4
bilinear_upsample(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})
@@ -118,13 +118,10 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
118
118
119
119
n_chan, n_batch = size (Δ, 3 ), size (Δ, 4 )
120
120
121
- kern1 = get_downsamplekernel (k[1 ])
122
- kern2 = get_downsamplekernel (k[2 ])
121
+ kern1 = get_downsamplekernel (Δ, k[1 ])
122
+ kern2 = get_downsamplekernel (Δ, k[2 ])
123
123
kern = kern1 .* kern2'
124
- # TODO there must be a more convenient way to send to gpu
125
- kern = convert (typeof (Δ), reshape (kern, size (kern)... , 1 , 1 ))
126
- kern = dropdims (kern, dims= (3 ,4 ))
127
-
124
+
128
125
pad = (floor (Int, k[1 ]// 2 ), floor (Int, k[2 ]// 2 ))
129
126
stride = k
130
127
weight = similar (Δ, eltype (Δ), (size (kern)... , n_chan, n_chan))
@@ -197,17 +194,9 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
197
194
return dx
198
195
end
199
196
200
-
201
- """
202
- get_downsamplekernel(n::Int)
203
-
204
- # Arguments
205
- - `n`: upsample factor for which a downsample kernel will be determined
206
-
207
- # Outputs
208
- - `kernel`: downsample kernel
209
- """
210
- function get_downsamplekernel (n:: Int )
197
+ # `n` upsample factor for which a downsample kernel will be determined.
198
+ # Δ is given in case of necessity of gpu conversion
199
+ function get_downsamplekernel (Δ, n:: Int )
211
200
step = 1 // n
212
201
if n % 2 == 0
213
202
start = step// 2
@@ -218,6 +207,9 @@ function get_downsamplekernel(n::Int)
218
207
upward = collect (start: step: 1 // 1 )
219
208
kernel = [upward; reverse (upward[1 : end - 1 ])]
220
209
end
210
+ # TODO there must be a more convenient way to send to gpu
211
+ kernel = convert (typeof (Δ), reshape (kernel, length (kernel), 1 , 1 , 1 ))
212
+ kernel = dropdims (kernel, dims= (2 ,3 ,4 ))
221
213
return kernel
222
214
end
223
215
@@ -228,3 +220,28 @@ function ChainRulesCore.rrule(::typeof(bilinear_upsample), x, k)
228
220
end
229
221
return Ω, bilinear_upsample_pullback
230
222
end
223
+
224
+
225
+ """
226
+ pixel_shuffle(x, r)
227
+
228
+ Pixel shuffling operation. `r` is the upscale factor for shuffling.
229
+ The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N]
230
+ Used extensively in super-resolution networks to upsample
231
+ towards high resolution features.
232
+
233
+ Reference : https://arxiv.org/pdf/1609.05158.pdf
234
+ """
235
+ function pixel_shuffle (x:: AbstractArray , r:: Integer )
236
+ @assert ndims (x) > 2
237
+ d = ndims (x) - 2
238
+ sizein = size (x)[1 : d]
239
+ cin, n = size (x, d+ 1 ), size (x, d+ 2 )
240
+ @assert cin % r^ d == 0
241
+ cout = cin ÷ r^ d
242
+ # x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866
243
+ x = reshape (x, sizein... , ntuple (i-> r, d)... , cout, n)
244
+ perm = [d+ 1 : 2 d 1 : d]' |> vec # = [d+1, 1, d+2, 2, ..., 2d, d]
245
+ x = permutedims (x, (perm... , 2 d+ 1 , 2 d+ 2 ))
246
+ return reshape (x, ((r .* sizein). .. , cout, n))
247
+ end
0 commit comments