export upsample_bilinear, ∇upsample_bilinear, pixel_shuffle

"""
    upsample_bilinear(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})

Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `k`,
using bilinear interpolation.

The size of the output is equal to
`(k[1]*S1, k[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.

The interpolation grid is identical to the one used by `imresize` from `Images.jl`.

Currently only 2d upsampling is supported.
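
For example, upsampling by `k = (2, 3)`:

```julia
x = reshape(Float32.(1:4), (2, 2, 1, 1))
y = upsample_bilinear(x, (2, 3))
size(y)  # (4, 6, 1, 1)
```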
| 15 | +""" |
| 16 | +function upsample_bilinear(x::AbstractArray{T,4}, k::NTuple{2,Int}) where T |
| 17 | + # This function is gpu friendly |
| 18 | + |
| 19 | + imgsize = size(x) |
| 20 | + newsize = get_newsize(imgsize, k) |
| 21 | + |
| 22 | + # Get linear interpolation lower- and upper index, and weights |
| 23 | + ilow1, ihigh1, wdiff1 = get_inds_and_ws(x, imgsize[1], newsize[1], 1) |
| 24 | + ilow2, ihigh2, wdiff2 = get_inds_and_ws(x, imgsize[2], newsize[2], 2) |
| 25 | + |
| 26 | + # Adjust the upper interpolation indices of the second dimension |
| 27 | + ihigh2_r = adjoint_of_idx(ilow2)[ihigh2] |
| 28 | + |
| 29 | + @inbounds y = @view(x[ilow1,ilow2,:,:]) .* (1 .- wdiff1) .+ @view(x[ihigh1,ilow2,:,:]) .* wdiff1 |
| 30 | + @inbounds y .= y .* (1 .- wdiff2) .+ y[:,ihigh2_r,:,:] .* wdiff2 |
| 31 | + # @inbounds y = y .* (1 .- wdiff2) .+ @view(y[:,ihigh2_r,:,:]) .* wdiff2 # equivalent to line above |
| 32 | + return y |
| 33 | +end |

function get_inds_and_ws(x::T, n::Int, m::Int, dim::Int) where T <: AbstractArray
    # Create the interpolation grid for resampling,
    # identical to the grid used by `imresize` in Images.jl.
    step = n // m
    offset = (n + 1)//2 - step//2 - step * (m//2 - 1)
    xq = clamp.(range(offset, step=step, length=m), 1, n)

    # Interpolation lower and upper indices, and broadcastable weights
    ilow = floor.(Int, xq)
    ihigh = ceil.(Int, xq)
    sizew = ntuple(i -> i == dim ? length(xq) : 1, ndims(x))
    wdiff = convert(T, reshape(xq .- ilow, sizew)) # wdiff possibly lives on the gpu
    return ilow, ihigh, wdiff
end
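
# A hand-worked example of the grid above (n = 2 upsampled to m = 4):
#   step = 1//2, offset = 3//4
#   xq    = [1, 5//4, 7//4, 2]    (after clamping to [1, n])
#   ilow  = [1, 1, 1, 2]
#   ihigh = [1, 2, 2, 2]
#   wdiff = [0, 1//4, 3//4, 0]    (reshaped to broadcast along `dim`)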

"""
    adjoint_of_idx(idx::Vector{<:Integer})

# Arguments
- `idx`: a vector of indices from which you want the adjoint.

# Outputs
- `idx_adjoint`: indices that invert the operation `x[idx]`.

# Explanation
Determines the adjoint of the vector of indices `idx`, based on the following assumptions:
* `idx[1] == 1`
* `all(d in [0,1] for d in diff(idx))`
The adjoint of `idx` can be seen as an inverse operation such that:

```julia
x = [1, 2, 3, 4, 5]
idx = [1, 2, 2, 3, 4, 4, 5]
idx_adjoint = adjoint_of_idx(idx)
@assert x[idx][idx_adjoint] == x
```
The above holds as long as `idx` contains every index in `x`.
"""
function adjoint_of_idx(idx::Vector{Int})
    # Mark the positions where `idx` increases, i.e. the first occurrence
    # of each index; `findall` then returns those positions.
    d = trues(length(idx))
    d[2:end] .= diff(idx)
    idx_adjoint = findall(d)
    return idx_adjoint
end
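
# For example, with idx = [1, 2, 2, 3, 4, 4, 5] (hand-worked):
#   diff(idx)   = [1, 0, 1, 1, 0, 1]
#   d           = [1, 1, 0, 1, 1, 0, 1]
#   idx_adjoint = [1, 2, 4, 5, 7]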

function get_newsize(sz, k)
    return ntuple(i -> i <= length(k) ? sz[i]*k[i] : sz[i], length(sz))
end
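
# For example, get_newsize((3, 4, 2, 1), (2, 2)) == (6, 8, 2, 1)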


"""
    ∇upsample_bilinear(Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int})

# Arguments
- `Δ`: array that has been upsampled using the upsample factors in `k`

# Outputs
- `dx`: downsampled version of `Δ`

# Explanation

Custom adjoint for [`upsample_bilinear`](@ref).
The adjoint of upsampling is a downsampling operation, which this
implementation performs using `NNlib.conv` with a downsampling kernel
derived from the upsampling factors. Because of the zero-padding during
convolution, the values at the boundary are polluted by edge-effects,
which are corrected manually.
"""
function ∇upsample_bilinear(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
    # This function is gpu friendly

    # Be more efficient on some corner cases
    if size(Δ, 1) == k[1]
        Δ = sum(Δ, dims=1)
        k = (1, k[2])
    end
    if size(Δ, 2) == k[2]
        Δ = sum(Δ, dims=2)
        k = (k[1], 1)
    end
    if (size(Δ, 1) == 1) && (size(Δ, 2) == 1)
        dx = Δ
        return dx
    end

    n_chan, n_batch = size(Δ, 3), size(Δ, 4)

    kern1 = get_downsamplekernel(Δ, k[1])
    kern2 = get_downsamplekernel(Δ, k[2])
    kern = kern1 * kern2'

    pad = (floor(Int, k[1]//2), floor(Int, k[2]//2))
    stride = k

    weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan))
    weight .= 0
    for i in 1:n_chan
        weight[:,:,i,i] .= kern
    end
    # weight = cat(fill(kern, n_chan)..., dims=(3,4)) # slow
    dx = conv(Δ, weight, pad=pad, stride=stride)

    # Still have to fix the edge effects caused by the zero-padding of the convolution.
    # TODO: could be circumvented by padding that extrapolates the value at the first/last index
    nextras = (floor(Int, k[1]//2), floor(Int, k[2]//2))

    # First dimension edge-effect correction
    if nextras[1] > 0
        kern1 = kern[1:nextras[1],:]
        pad1 = (0, pad[2])
        stride1 = (1, stride[2])
        weight1 = similar(Δ, eltype(Δ), (size(kern1)..., n_chan, n_chan))
        weight1 .= 0
        for i in 1:n_chan
            weight1[:,:,i,i] .= kern1
        end
        # weight1 = cat(fill(kern1, n_chan)..., dims=(3,4)) # slow
        dx[[1],:,:,:] .+= conv(Δ[1:nextras[1],:,:,:], weight1, pad=pad1, stride=stride1)
        weight1 .= weight1[end:-1:1,:,:,:]
        dx[[end],:,:,:] .+= conv(Δ[end-nextras[1]+1:end,:,:,:], weight1, pad=pad1, stride=stride1)

        ## Conv with views is not dispatched to CUDA.conv
        # dx[[1],:,:,:] .+= conv(@view(Δ[1:nextras[1],:,:,:]), weight1, pad=pad1, stride=stride1)
        # weight1 .= @view(weight1[end:-1:1,:,:,:])
        # dx[[end],:,:,:] .+= conv(@view(Δ[end-nextras[1]+1:end,:,:,:]), weight1, pad=pad1, stride=stride1)
    end

    # Second dimension edge-effect correction
    if nextras[2] > 0
        kern2 = kern[:,1:nextras[2]]
        pad2 = (pad[1], 0)
        stride2 = (stride[1], 1)
        weight2 = similar(Δ, eltype(Δ), (size(kern2)..., n_chan, n_chan))
        weight2 .= 0
        for i in 1:n_chan
            weight2[:,:,i,i] .= kern2
        end
        # weight2 = cat(fill(kern2, n_chan)..., dims=(3,4)) # slow

        dx[:,[1],:,:] .+= conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2)
        weight2 .= weight2[:,end:-1:1,:,:]
        dx[:,[end],:,:] .+= conv(Δ[:,end-nextras[2]+1:end,:,:], weight2, pad=pad2, stride=stride2)

        ## Conv with views is not dispatched to CUDA.conv
        # dx[:,[1],:,:] .+= conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2)
        # weight2 .= @view(weight2[:,end:-1:1,:,:])
        # dx[:,[end],:,:] .+= conv(@view(Δ[:,end-nextras[2]+1:end,:,:]), weight2, pad=pad2, stride=stride2)
    end

    ## Finally fix the four corners if needed
    n1, n2 = nextras
    if (n1 > 0) && (n2 > 0)
        dx[1,1,:,:] .+= sum(kern[1:n1,1:n2] .* @view(Δ[1:n1,1:n2,:,:]), dims=(1,2))[1,1,:,:]
        dx[1,end,:,:] .+= sum(kern[1:n1,end-n2+1:end] .* @view(Δ[1:n1,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:]
        dx[end,end,:,:] .+= sum(kern[end-n1+1:end,end-n2+1:end] .* @view(Δ[end-n1+1:end,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:]
        dx[end,1,:,:] .+= sum(kern[end-n1+1:end,1:n2] .* @view(Δ[end-n1+1:end,1:n2,:,:]), dims=(1,2))[1,1,:,:]
    end

    return dx
end
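
# A minimal shape sketch (illustrative values only):
#   Δ = ones(Float32, 4, 6, 1, 1)      # e.g. the output of upsample_bilinear(x, (2, 3)) for x of size (2, 2, 1, 1)
#   dx = ∇upsample_bilinear(Δ, (2, 3))
#   size(dx) == (2, 2, 1, 1)           # the gradient has the size of the original input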

# `n`: the upsample factor for which a downsampling kernel is determined.
# `Δ` is passed along so the kernel can be converted to the gpu if needed.
function get_downsamplekernel(Δ, n::Int)
    step = 1//n
    if n % 2 == 0
        start = step//2
        upward = collect(start:step:1//1)
        kernel = [upward; reverse(upward)]
    else
        start = step
        upward = collect(start:step:1//1)
        kernel = [upward; reverse(upward[1:end-1])]
    end
    # TODO: there must be a more convenient way to send to the gpu
    kernel = convert(typeof(Δ), reshape(kernel, length(kernel), 1, 1, 1))
    kernel = dropdims(kernel, dims=(2,3,4))
    return kernel
end
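
# Hand-worked kernels for illustration:
#   n = 2 -> [1//4, 3//4, 3//4, 1//4]      (length 2n)
#   n = 3 -> [1//3, 2//3, 1, 2//3, 1//3]   (length 2n - 1)
# i.e. the triangular (tent) kernel used for the downsampling convolution.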

function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, k)
    Ω = upsample_bilinear(x, k)
    function upsample_bilinear_pullback(Δ)
        (NO_FIELDS, ∇upsample_bilinear(Δ, k), DoesNotExist())
    end
    return Ω, upsample_bilinear_pullback
end
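
# With this rrule, ChainRules-aware AD (e.g. Zygote) can differentiate through
# `upsample_bilinear`; a minimal sketch, assuming Zygote is loaded:
#   g, = Zygote.gradient(x -> sum(upsample_bilinear(x, (2, 2))), rand(Float32, 2, 2, 1, 1))
#   size(g) == (2, 2, 1, 1)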


"""
    pixel_shuffle(x, r)

Pixel shuffling operation. `r` is the upscale factor for shuffling.
The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N].
Used extensively in super-resolution networks to upsample
towards high resolution features.

Reference: https://arxiv.org/pdf/1609.05158.pdf
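
For example (a minimal sketch):

```julia
x = reshape(Float32.(1:16), (2, 2, 4, 1))
y = pixel_shuffle(x, 2)
size(y)  # (4, 4, 1, 1)
```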
| 235 | +""" |
| 236 | +function pixel_shuffle(x::AbstractArray, r::Integer) |
| 237 | + @assert ndims(x) > 2 |
| 238 | + d = ndims(x) - 2 |
| 239 | + sizein = size(x)[1:d] |
| 240 | + cin, n = size(x, d+1), size(x, d+2) |
| 241 | + @assert cin % r^d == 0 |
| 242 | + cout = cin ÷ r^d |
| 243 | + # x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866 |
| 244 | + x = reshape(x, sizein..., ntuple(i->r, d)..., cout, n) |
| 245 | + perm = [d+1:2d 1:d]' |> vec # = [d+1, 1, d+2, 2, ..., 2d, d] |
| 246 | + x = permutedims(x, (perm..., 2d+1, 2d+2)) |
| 247 | + return reshape(x, ((r .* sizein)..., cout, n)) |
| 248 | +end |