From 5703cc8e53f88bbcbea2c42b971e4061f8cdbdb2 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Wed, 30 Dec 2020 15:28:42 +0100
Subject: [PATCH 01/12] add bilinear upsampling

Co-authored-by: ltjkoomen
---
 src/upsample.jl | 263 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 263 insertions(+)
 create mode 100644 src/upsample.jl

diff --git a/src/upsample.jl b/src/upsample.jl
new file mode 100644
index 000000000..a8801e552
--- /dev/null
+++ b/src/upsample.jl
@@ -0,0 +1,263 @@
+export bilinear_upsample, ∇bilinear_upsample
+
+# Creates interpolation points for resampling; creates the same grid as used in Images.jl `imresize`.
+function construct_xq(n::T, m::T) where T<:Integer
+    typed1 = one(n)
+    typed2 = 2typed1
+    step = n // m
+    offset = (n + typed1)//typed2 - step//typed2 - step*(m//typed2 - typed1)
+    x = range(offset, step=step, length=m)
+    xq = clamp.(x, typed1//typed1, n//typed1)
+    return xq
+end
+
+# Creates interpolation lower and upper indices, and broadcastable weights
+function get_inds_and_ws(xq, dim)
+    n = length(xq)
+    ilow = floor.(Int, xq)
+    ihigh = ceil.(Int, xq)
+    wdiff = xq .- ilow
+    if dim == 1
+        newsizetup = (n, 1, 1, 1)
+    elseif dim == 2
+        newsizetup = (1, n, 1, 1)
+    else
+        error("Unreachable reached")
+    end
+    wdiff = reshape(wdiff, newsizetup)
+    return ilow, ihigh, wdiff
+end
+
+"""
+    adjoint_of_idx(idx::Vector{<:Integer})
+
+# Arguments
+- `idx`: a vector of indices from which you want the adjoint.
+
+# Outputs
+- `idx_adjoint`: index that inverts the operation `x[idx]`.
+
+# Explanation
+Determines the adjoint of the vector of indices `idx`, based on the following assumptions:
+* `idx[1] == 1`
+* `all(d in [0,1] for d in diff(idx))`
+The adjoint of `idx` can be seen as an inverse operation such that:
+
+```julia
+x = [1, 2, 3, 4, 5]
+idx = [1, 2, 2, 3, 4, 4, 5]
+idx_adjoint = adjoint_of_idx(idx)
+@assert x[idx][idx_adjoint] == x
+```
+The above holds as long as `idx` contains every index in `x`.
+"""
+function adjoint_of_idx(idx::Vector{T}) where T<:Integer
+    d = trues(size(idx))
+    d[2:end] .= diff(idx, dims=1)
+    idx_adjoint = findall(d)
+    return idx_adjoint
+end
+
+function get_newsize(oldsize, k_upsample)
+    newsize = (i <= length(k_upsample) ? s*k_upsample[i] : s for (i,s) in enumerate(oldsize))
+    return tuple(newsize...)
+end
+
+"""
+    bilinear_upsample(img::AbstractArray{T,4}, k_upsample::NTuple{2,<:Real}) where T
+
+# Arguments
+- `img::AbstractArray`: the 4-dimensional array to be upsampled.
+- `k_upsample`: a tuple containing the factors with which the first two dimensions of `img` are upsampled.
+
+# Outputs
+- `imgupsampled`: the upsampled version of `img`. The size of `imgupsampled` is
+equal to `(k_upsample[1]*S1, k_upsample[2]*S2, S3, S4)`, where `S1,S2,S3,S4 = size(img)`.
+
+# Explanation
+Upsamples the first two dimensions of the 4-dimensional array `img` by the two upsample factors stored in `k_upsample`,
+using bilinear interpolation. The interpolation grid is identical to the one used by `imresize` from `Images.jl`.
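+
+# Example
+A minimal usage sketch; the output size follows from `get_newsize`, and the
+values match the tests added later in this series:
+
+```julia
+img = reshape(Float32[1 2; 3 4], (2, 2, 1, 1))
+imgupsampled = bilinear_upsample(img, (3, 2))  # size(imgupsampled) == (6, 4, 1, 1)
+```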
+""" +function bilinear_upsample(img::AbstractArray{T,4}, k_upsample::NTuple{2,<:Real}) where T + + ilow1, ihigh1, wdiff1, ilow2, ihigh2, wdiff2, ihigh2_r = setup_upsample(img, k_upsample) + + @inbounds imgupsampled = @view(img[ilow1,ilow2,:,:]) .* (1 .- wdiff1) .+ @view(img[ihigh1,ilow2,:,:]) .* wdiff1 + @inbounds imgupsampled = imgupsampled .* (1 .- wdiff2) .+ @view(imgupsampled[:,ihigh2_r,:,:]) .* wdiff2 + + return imgupsampled +end + +""" + setup_upsample(imgsize::NTuple{4,<:Integer}, imgdtype, k_upsample::NTuple{2,<:Real}) + +Creates arrays of interpolation indices and weights for the bilinear_upsample2d operation. +""" +function setup_upsample(img, k_upsample::NTuple{2,<:Real}) + n_dims = 4 + imgsize = size(img) + newsize = get_newsize(imgsize, k_upsample) + + # Create interpolation grids + xq1 = construct_xq(imgsize[1], newsize[1]) + xq2 = construct_xq(imgsize[2], newsize[2]) + + # Get linear interpolation lower- and upper index, and weights + ilow1, ihigh1, wdiff1 = get_inds_and_ws(xq1, 1) + ilow2, ihigh2, wdiff2 = get_inds_and_ws(xq2, 2) + + # Adjust the upper interpolation indices of the second dimension + ihigh2_r = adjoint_of_idx(ilow2)[ihigh2] + + wdiff1 = eltype(img).(wdiff1) + wdiff2 = eltype(img).(wdiff2) + + # if typeof(img) <: CuArray + # wdiff1 = CuArray(wdiff1) + # wdiff2 = CuArray(wdiff2) + # end + + return ilow1, ihigh1, wdiff1, ilow2, ihigh2, wdiff2, ihigh2_r + +end + +""" + get_downsamplekernel(n::T) where T<:Integer + +# Arguments +- `n<:Integer`: upsample factor for which a downsample kernel will be determined + +# Outputs +- `kernel`: downsample kernel +""" +function get_downsamplekernel(n::T) where T<:Integer + step = 1//n + if n % 2 == 0 + start = step//2 + upward = collect(start:step:1//1) + kernel = [upward; reverse(upward)] + else + start = step + upward = collect(start:step:1//1) + kernel = [upward; reverse(upward[1:end-1])] + end + return kernel +end + +""" + ∇bilinear_upsample(arr::AbstractArray, factors::Tuple{T,T} where T<:Integer) + +# Arguments +- `arr::AbstractArray`: array that has been upsampled using the upsample factors in `factors` + +# Outputs +- `arr_ds`: downsampled version of `arr` + +# Explanation +Custom adjoint for `BilinearUpsample2d`. Needed because Zygote cannot properly determine gradients +for the current implementation of the forward pass. The adjoint of upsampling is a downsampling operation, which +in this implementation is performed using `Flux.Conv` in combination with a downsampling kernel based on the +upsampling factors. Because of the zero-padding during convolution, the values at the boundary are polluted by edge-effects, +which have been corrected for manually. 
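+
+Since upsampling is linear in its input, `∇bilinear_upsample` is (up to the
+edge corrections above) its adjoint. An illustrative check of that relation,
+assuming the `Flux.Conv` dependency mentioned above is available; equality
+holds only up to floating-point error:
+
+```julia
+x = rand(Float32, 4, 4, 1, 1)
+Δ = rand(Float32, 8, 8, 1, 1)  # same size as bilinear_upsample(x, (2, 2))
+sum(Δ .* bilinear_upsample(x, (2, 2))) ≈ sum(∇bilinear_upsample(Δ, (2, 2)) .* x)
+```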
+""" +function ∇bilinear_upsample(arr::AbstractArray, factors::Tuple{T,T} where T<:Integer) + + if size(arr,1) == factors[1] + arr = sum(arr, dims=1) + factors = (1, factors[2]) + end + + if size(arr,2) == factors[2] + arr = sum(arr, dims=2) + factors = (factors[1], 1) + end + + if (size(arr,1) == 1) & (size(arr,2) == 1) + ds_arr = arr + return ds_arr + end + + n_chan, n_batch = size(arr,3), size(arr,4) + + kern1 = get_downsamplekernel(factors[1]) + kern2 = get_downsamplekernel(factors[2]) + kern = kern1 .* kern2' + + kern_sizes = size(kern) + pads = (floor(Int, factors[1]//2), floor(Int, factors[2]//2)) + strides = factors + + conv_ds = Conv(kern_sizes, n_chan=>n_chan, pad=pads, stride=strides) + + conv_ds.weight .*= 0 + for i in 1:n_chan + conv_ds.weight[:,:,i,i] .= kern + end + conv_ds.bias .*= 0 + + # if arr isa CuArray + # conv_ds = gpu(conv_ds) + # end + + arr_ds = conv_ds(arr) + + # Still have to fix edge effects due to zero-padding of convolution, + # TODO: Could be circumvented by having padding that just extrapolates the value at the first/last index + # nextras = tuple((Int.(floor(factor//2)) for factor in factors)...) + nextras = (floor(Int, factors[1]//2), floor(Int, factors[2]//2)) + + # First dimension edge-effect correction + if nextras[1] > 0 + kern_extra1 = kern[1:nextras[1],:] + conv_extra1 = Conv(size(kern_extra1), n_chan=>n_chan, pad=(0,pads[2]), stride=(1,strides[2])) + + conv_extra1.weight .*= 0 + for i in 1:n_chan + conv_extra1.weight[:,:,i,i] .= kern_extra1 + end + conv_extra1.bias .*= 0 + + # if typeof(arr) <: CuArray + # conv_extra1 = gpu(conv_extra1) + # end + + arr_ds[[1],:,:,:] .+= conv_extra1(@view(arr[1:nextras[1],:,:,:])) + conv_extra1.weight .= @view(conv_extra1.weight[end:-1:1,:,:,:]) + arr_ds[[end],:,:,:] .+= conv_extra1(@view(arr[end-nextras[1]+1:end,:,:,:])) + end + + # Second dimension edge-effect correction + if nextras[2] > 0 + kern_extra2 = kern[:,1:nextras[2]] + conv_extra2 = Conv(size(kern_extra2), n_chan=>n_chan, pad=(pads[1],0), stride=(strides[1],1)) + + conv_extra2.weight .*= 0 + for i in 1:n_chan + conv_extra2.weight[:,:,i,i] .= kern_extra2 + end + conv_extra2.bias .*= 0 + + # if typeof(arr) <: CuArray + # conv_extra2 = gpu(conv_extra2) + # end + + arr_ds[:,[1],:,:] .+= conv_extra2(@view(arr[:,1:nextras[2],:,:])) + conv_extra2.weight .= @view(conv_extra2.weight[:,end:-1:1,:,:]) + arr_ds[:,[end],:,:] .+= conv_extra2(@view(arr[:,end-nextras[2]+1:end,:,:])) + end + + # Finally fix four corners if needed + # kern = eltype(arr).(kern) + # if typeof(arr) <: CuArray + # kern = gpu(kern) + # end + n1, n2 = nextras + if (n1 > 0) & (n2 > 0) + arr_ds[1,1,:,:] .+= sum(kern[1:n1,1:n2] .* @view(arr[1:n1,1:n2,:,:]), dims=(1,2))[1,1,:,:] + arr_ds[1,end,:,:] .+= sum(kern[1:n1,end-n2+1:end] .* @view(arr[1:n1,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:] + arr_ds[end,end,:,:] .+= sum(kern[end-n1+1:end,end-n2+1:end] .* @view(arr[end-n1+1:end,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:] + arr_ds[end,1,:,:] .+= sum(kern[end-n1+1:end,1:n2] .* @view(arr[end-n1+1:end,1:n2,:,:]), dims=(1,2))[1,1,:,:] + end + + return arr_ds +end From 69d62428c75e33b96dcc267f6e539346fe4fef60 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 30 Dec 2020 17:56:56 +0100 Subject: [PATCH 02/12] refactoring; working forward --- src/NNlib.jl | 2 +- src/upsample.jl | 170 ++++++++++++++++++----------------------------- test/runtests.jl | 4 +- test/upsample.jl | 20 ++++++ 4 files changed, 89 insertions(+), 107 deletions(-) create mode 100644 test/upsample.jl diff --git a/src/NNlib.jl 
b/src/NNlib.jl
index 5f4851fd7..5e4df277a 100644
--- a/src/NNlib.jl
+++ b/src/NNlib.jl
@@ -24,7 +24,6 @@ is_nnpack_available() = false
 end
 
 include("activations.jl")
-
 include("softmax.jl")
 include("misc.jl")
 include("batched/batchedmul.jl")
@@ -32,6 +31,7 @@ include("gemm.jl")
 include("conv.jl")
 include("conv_bias_act.jl")
 include("pooling.jl")
+include("upsample.jl")
 
 ## Include implementations
 include("impl/padding_edges.jl")
diff --git a/src/upsample.jl b/src/upsample.jl
index a8801e552..cfa263937 100644
--- a/src/upsample.jl
+++ b/src/upsample.jl
@@ -1,30 +1,47 @@
 export bilinear_upsample, ∇bilinear_upsample
 
-# Creates interpolation points for resampling; creates the same grid as used in Images.jl `imresize`.
-function construct_xq(n::T, m::T) where T<:Integer
-    typed1 = one(n)
-    typed2 = 2typed1
-    step = n // m
-    offset = (n + typed1)//typed2 - step//typed2 - step*(m//typed2 - typed1)
-    x = range(offset, step=step, length=m)
-    xq = clamp.(x, typed1//typed1, n//typed1)
-    return xq
+
+"""
+    bilinear_upsample(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})
+
+Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `k`,
+using bilinear interpolation.
+
+The size of the output is equal to
+`(k[1]*S1, k[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.
+
+The interpolation grid is identical to the one used by `imresize` from `Images.jl`.
+
+Currently only 2d upsampling is supported.
+"""
+function bilinear_upsample(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})
+    imgsize = size(x)
+    newsize = get_newsize(imgsize, k)
+
+    # Get linear interpolation lower- and upper index, and weights
+    ilow1, ihigh1, wdiff1 = get_inds_and_ws(x, imgsize[1], newsize[1], 1)
+    ilow2, ihigh2, wdiff2 = get_inds_and_ws(x, imgsize[2], newsize[2], 2)
+
+    # Adjust the upper interpolation indices of the second dimension
+    ihigh2_r = adjoint_of_idx(ilow2)[ihigh2]
+
+    @inbounds y = @view(x[ilow1,ilow2,:,:]) .* (1 .- wdiff1) .+ @view(x[ihigh1,ilow2,:,:]) .* wdiff1
+    @inbounds y = y .* (1 .- wdiff2) .+ @view(y[:,ihigh2_r,:,:]) .* wdiff2
+    return y
 end
 
-# Creates interpolation lower and upper indices, and broadcastable weights
-function get_inds_and_ws(xq, dim)
-    n = length(xq)
+function get_inds_and_ws(x::T, n::Int, m::Int, dim::Int) where T <: AbstractArray
+    # Creates interpolation grid for resampling.
+    # Creates the same grid as used in Images.jl `imresize`.
+    step = n // m
+    offset = (n + 1)//2 - step//2 - step * (m//2 - 1)
+    xq = clamp.(range(offset, step=step, length=m), 1, n)
+
+    # Creates interpolation lower and upper indices, and broadcastable weights
     ilow = floor.(Int, xq)
     ihigh = ceil.(Int, xq)
-    wdiff = xq .- ilow
-    if dim == 1
-        newsizetup = (n, 1, 1, 1)
-    elseif dim == 2
-        newsizetup = (1, n, 1, 1)
-    else
-        error("Unreachable reached")
-    end
-    wdiff = reshape(wdiff, newsizetup)
+    sizew = ntuple(i-> i == dim ? length(xq) : 1, ndims(x))
+    wdiff = convert(T, reshape(xq .- ilow, sizew)) # wdiff possibly lives on gpu
     return ilow, ihigh, wdiff
 end
@@ -58,91 +75,11 @@ function adjoint_of_idx(idx::Vector{T}) where T<:Integer
     return idx_adjoint
 end
 
-function get_newsize(oldsize, k_upsample)
-    newsize = (i <= length(k_upsample) ? s*k_upsample[i] : s for (i,s) in enumerate(oldsize))
+function get_newsize(oldsize, k)
+    newsize = (i <= length(k) ? s*k[i] : s for (i,s) in enumerate(oldsize))
     return tuple(newsize...)
 end
 
-"""
-    bilinear_upsample(img::AbstractArray{T,4}, k_upsample::NTuple{2,<:Real}) where T
-
-# Arguments
-- `img::AbstractArray`: the 4-dimensional array to be upsampled.
-- `k_upsample`: a tuple containing the factors with which the first two dimensions of `img` are upsampled. - -# Outputs -- `imgupsampled`: the upsampled version of `img`. The size of `imgupsampled` is -equal to `(k_upsample[1]*S1, k_upsample[2]*S2, S3, S4)`, where `S1,S2,S3,S4 = size(img)`. - -# Explanation -Upsamples the first two dimensions of the 4-dimensional array `img` by the two upsample factors stored in `k_upsample`, -using bilinear interpolation. The interpolation grid is identical to the one used by `imresize` from `Images.jl`. -""" -function bilinear_upsample(img::AbstractArray{T,4}, k_upsample::NTuple{2,<:Real}) where T - - ilow1, ihigh1, wdiff1, ilow2, ihigh2, wdiff2, ihigh2_r = setup_upsample(img, k_upsample) - - @inbounds imgupsampled = @view(img[ilow1,ilow2,:,:]) .* (1 .- wdiff1) .+ @view(img[ihigh1,ilow2,:,:]) .* wdiff1 - @inbounds imgupsampled = imgupsampled .* (1 .- wdiff2) .+ @view(imgupsampled[:,ihigh2_r,:,:]) .* wdiff2 - - return imgupsampled -end - -""" - setup_upsample(imgsize::NTuple{4,<:Integer}, imgdtype, k_upsample::NTuple{2,<:Real}) - -Creates arrays of interpolation indices and weights for the bilinear_upsample2d operation. -""" -function setup_upsample(img, k_upsample::NTuple{2,<:Real}) - n_dims = 4 - imgsize = size(img) - newsize = get_newsize(imgsize, k_upsample) - - # Create interpolation grids - xq1 = construct_xq(imgsize[1], newsize[1]) - xq2 = construct_xq(imgsize[2], newsize[2]) - - # Get linear interpolation lower- and upper index, and weights - ilow1, ihigh1, wdiff1 = get_inds_and_ws(xq1, 1) - ilow2, ihigh2, wdiff2 = get_inds_and_ws(xq2, 2) - - # Adjust the upper interpolation indices of the second dimension - ihigh2_r = adjoint_of_idx(ilow2)[ihigh2] - - wdiff1 = eltype(img).(wdiff1) - wdiff2 = eltype(img).(wdiff2) - - # if typeof(img) <: CuArray - # wdiff1 = CuArray(wdiff1) - # wdiff2 = CuArray(wdiff2) - # end - - return ilow1, ihigh1, wdiff1, ilow2, ihigh2, wdiff2, ihigh2_r - -end - -""" - get_downsamplekernel(n::T) where T<:Integer - -# Arguments -- `n<:Integer`: upsample factor for which a downsample kernel will be determined - -# Outputs -- `kernel`: downsample kernel -""" -function get_downsamplekernel(n::T) where T<:Integer - step = 1//n - if n % 2 == 0 - start = step//2 - upward = collect(start:step:1//1) - kernel = [upward; reverse(upward)] - else - start = step - upward = collect(start:step:1//1) - kernel = [upward; reverse(upward[1:end-1])] - end - return kernel -end """ ∇bilinear_upsample(arr::AbstractArray, factors::Tuple{T,T} where T<:Integer) @@ -160,7 +97,7 @@ in this implementation is performed using `Flux.Conv` in combination with a down upsampling factors. Because of the zero-padding during convolution, the values at the boundary are polluted by edge-effects, which have been corrected for manually. 
""" -function ∇bilinear_upsample(arr::AbstractArray, factors::Tuple{T,T} where T<:Integer) +function ∇bilinear_upsample(arr::AbstractArray{<:Number, 4}, factors::NTuple{2,Int}) if size(arr,1) == factors[1] arr = sum(arr, dims=1) @@ -261,3 +198,28 @@ function ∇bilinear_upsample(arr::AbstractArray, factors::Tuple{T,T} where T<:I return arr_ds end + + +""" + get_downsamplekernel(n::T) where T<:Integer + +# Arguments +- `n<:Integer`: upsample factor for which a downsample kernel will be determined + +# Outputs +- `kernel`: downsample kernel +""" +function get_downsamplekernel(n::T) where T<:Integer + step = 1//n + if n % 2 == 0 + start = step//2 + upward = collect(start:step:1//1) + kernel = [upward; reverse(upward)] + else + start = step + upward = collect(start:step:1//1) + kernel = [upward; reverse(upward[1:end-1])] + end + return kernel +end + diff --git a/test/runtests.jl b/test/runtests.jl index 0ef0676fc..6f431b9f6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,6 +36,6 @@ end include("softmax.jl") end -@testset "Misc Stuff" begin - include("misc.jl") +@testset "Upsampling" begin + include("upsample.jl") end diff --git a/test/upsample.jl b/test/upsample.jl new file mode 100644 index 000000000..2b2797a06 --- /dev/null +++ b/test/upsample.jl @@ -0,0 +1,20 @@ +using CUDA + +@testset "bilinear_upsample 2d" begin + x = reshape(Float32[1. 2.; 3. 4.], (2,2,1,1)) + y_true = [1//1 5//4 7//4 2//1; + 1//1 5//4 7//4 2//1; + 5//3 23//12 29//12 8//3; + 7//3 31//12 37//12 10//3; + 3//1 13//4 15//4 4//1; + 3//1 13//4 15//4 4//1][:,:,:,:] + + y = bilinear_upsample(x, (3, 2)) + @test size(y) == size(y_true) + @test eltype(y) == Float32 + @test y ≈ y_true + + y = bilinear_upsample(x |> cu, (3, 2)) + @test y isa CuArray + @test Array(y) ≈ y_true +end From 92de1d8712c4a8614d2b633bbc6afd68173fad15 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 30 Dec 2020 19:48:22 +0100 Subject: [PATCH 03/12] make it work --- Project.toml | 1 + src/upsample.jl | 173 ++++++++++++++++++++++++----------------------- test/upsample.jl | 10 +++ 3 files changed, 100 insertions(+), 84 deletions(-) diff --git a/Project.toml b/Project.toml index 7639e89b0..3136e07c1 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] diff --git a/src/upsample.jl b/src/upsample.jl index cfa263937..15c7c7bc5 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -1,6 +1,5 @@ export bilinear_upsample, ∇bilinear_upsample - """ bilinear_upsample(x::AbstractArray{<:Number,4}, k::NTuple{2,Int}) @@ -15,6 +14,8 @@ The interpolation grid is identical to the one used by `imresize` from `Images.j Currently only 2d upsampling is supported. """ function bilinear_upsample(x::AbstractArray{<:Number,4}, k::NTuple{2,Int}) + # This function is gpu friendly + imgsize = size(x) newsize = get_newsize(imgsize, k) @@ -68,148 +69,145 @@ idx_adjoint = adjoint_of_idx(idx) ``` The above holds as long as `idx` contains every index in `x`. 
""" -function adjoint_of_idx(idx::Vector{T}) where T<:Integer - d = trues(size(idx)) - d[2:end] .= diff(idx, dims=1) +function adjoint_of_idx(idx::Vector{Int}) + d = trues(length(idx)) + d[2:end] .= diff(idx) idx_adjoint = findall(d) return idx_adjoint end -function get_newsize(oldsize, k) - newsize = (i <= length(k) ? s*k[i] : s for (i,s) in enumerate(oldsize)) - return tuple(newsize...) +function get_newsize(sz, k) + return ntuple(i -> i <= length(k) ? sz[i]*k[i] : sz[i], length(sz)) end """ - ∇bilinear_upsample(arr::AbstractArray, factors::Tuple{T,T} where T<:Integer) + ∇bilinear_upsample(Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int}) # Arguments -- `arr::AbstractArray`: array that has been upsampled using the upsample factors in `factors` +- `Δ`: array that has been upsampled using the upsample factors in `k` # Outputs -- `arr_ds`: downsampled version of `arr` +- `dx`: downsampled version of `Δ` # Explanation -Custom adjoint for `BilinearUpsample2d`. Needed because Zygote cannot properly determine gradients -for the current implementation of the forward pass. The adjoint of upsampling is a downsampling operation, which -in this implementation is performed using `Flux.Conv` in combination with a downsampling kernel based on the + +Custom adjoint for [`bilinear_upsample`](@ref). +The adjoint of upsampling is a downsampling operation, which +in this implementation is performed using `NNlib.conv` in combination with a downsampling kernel based on the upsampling factors. Because of the zero-padding during convolution, the values at the boundary are polluted by edge-effects, which have been corrected for manually. """ -function ∇bilinear_upsample(arr::AbstractArray{<:Number, 4}, factors::NTuple{2,Int}) - - if size(arr,1) == factors[1] - arr = sum(arr, dims=1) - factors = (1, factors[2]) +function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) + # This function is gpu friendly + + if size(Δ, 1) == k[1] + Δ = sum(Δ, dims=1) + k = (1, k[2]) end - if size(arr,2) == factors[2] - arr = sum(arr, dims=2) - factors = (factors[1], 1) + if size(Δ, 2) == k[2] + Δ = sum(Δ, dims=2) + k = (k[1], 1) end - if (size(arr,1) == 1) & (size(arr,2) == 1) - ds_arr = arr - return ds_arr + if (size(Δ, 1) == 1) & (size(Δ, 2) == 1) + dx = Δ + return dx end - n_chan, n_batch = size(arr,3), size(arr,4) + n_chan, n_batch = size(Δ,3), size(Δ,4) - kern1 = get_downsamplekernel(factors[1]) - kern2 = get_downsamplekernel(factors[2]) + kern1 = get_downsamplekernel(k[1]) + kern2 = get_downsamplekernel(k[2]) kern = kern1 .* kern2' - - kern_sizes = size(kern) - pads = (floor(Int, factors[1]//2), floor(Int, factors[2]//2)) - strides = factors - - conv_ds = Conv(kern_sizes, n_chan=>n_chan, pad=pads, stride=strides) - - conv_ds.weight .*= 0 + # TODO there must be a more convenient way to send to gpu + kern = convert(typeof(Δ), reshape(kern, size(kern)..., 1, 1)) + kern = dropdims(kern, dims=(3,4)) + + pad = (floor(Int, k[1]//2), floor(Int, k[2]//2)) + stride = k + weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan)) + weight .= 0 + for i in 1:n_chan - conv_ds.weight[:,:,i,i] .= kern + weight[:,:,i,i] .= kern end - conv_ds.bias .*= 0 - - # if arr isa CuArray - # conv_ds = gpu(conv_ds) - # end - - arr_ds = conv_ds(arr) + + dx = conv(Δ, weight, pad=pad, stride=stride) # Still have to fix edge effects due to zero-padding of convolution, # TODO: Could be circumvented by having padding that just extrapolates the value at the first/last index - # nextras = tuple((Int.(floor(factor//2)) for factor in 
factors)...) - nextras = (floor(Int, factors[1]//2), floor(Int, factors[2]//2)) + # nextras = tuple((Int.(floor(factor//2)) for factor in k)...) + nextras = (floor(Int, k[1]//2), floor(Int, k[2]//2)) # First dimension edge-effect correction if nextras[1] > 0 - kern_extra1 = kern[1:nextras[1],:] - conv_extra1 = Conv(size(kern_extra1), n_chan=>n_chan, pad=(0,pads[2]), stride=(1,strides[2])) - - conv_extra1.weight .*= 0 + kern1 = kern[1:nextras[1],:] + pad1 = (0, pad[2]) + stride1 = (1, stride[2]) + weight1 = similar(Δ, eltype(Δ), (size(kern1)..., n_chan, n_chan)) + weight1 .= 0 for i in 1:n_chan - conv_extra1.weight[:,:,i,i] .= kern_extra1 + weight1[:,:,i,i] .= kern1 end - conv_extra1.bias .*= 0 - - # if typeof(arr) <: CuArray - # conv_extra1 = gpu(conv_extra1) - # end - - arr_ds[[1],:,:,:] .+= conv_extra1(@view(arr[1:nextras[1],:,:,:])) - conv_extra1.weight .= @view(conv_extra1.weight[end:-1:1,:,:,:]) - arr_ds[[end],:,:,:] .+= conv_extra1(@view(arr[end-nextras[1]+1:end,:,:,:])) + + dx[[1],:,:,:] .+= conv(Δ[1:nextras[1],:,:,:], weight1, pad=pad1, stride=stride1) + weight1 .= weight1[end:-1:1,:,:,:] + dx[[end],:,:,:] .+= conv(Δ[end-nextras[1]+1:end,:,:,:], weight1, pad=pad1, stride=stride1) + + ## Conv with views is not dispatched to CUDA.conv + # dx[[1],:,:,:] .+= conv(@view(Δ[1:nextras[1],:,:,:]), weight1, pad=pad1, stride=stride1) + # weight1 .= @view(weight1[end:-1:1,:,:,:]) + # dx[[end],:,:,:] .+= conv(@view(Δ[end-nextras[1]+1:end,:,:,:]), weight1, pad=pad1, stride=stride1) end # Second dimension edge-effect correction if nextras[2] > 0 - kern_extra2 = kern[:,1:nextras[2]] - conv_extra2 = Conv(size(kern_extra2), n_chan=>n_chan, pad=(pads[1],0), stride=(strides[1],1)) - - conv_extra2.weight .*= 0 + kern2 = kern[:,1:nextras[2]] + pad2 = (pad[1], 0) + stride2 = (stride[1], 1) + weight2 = similar(Δ, eltype(Δ), (size(kern2)..., n_chan, n_chan)) + weight2 .= 0 for i in 1:n_chan - conv_extra2.weight[:,:,i,i] .= kern_extra2 + weight2[:,:,i,i] .= kern2 end - conv_extra2.bias .*= 0 - # if typeof(arr) <: CuArray - # conv_extra2 = gpu(conv_extra2) - # end + yy = conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2) + dx[:,[1],:,:] .+= conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2) + weight2 .= weight2[:,end:-1:1,:,:] + dx[:,[end],:,:] .+= conv(Δ[:,end-nextras[2]+1:end,:,:], weight2, pad=pad2, stride=stride2) - arr_ds[:,[1],:,:] .+= conv_extra2(@view(arr[:,1:nextras[2],:,:])) - conv_extra2.weight .= @view(conv_extra2.weight[:,end:-1:1,:,:]) - arr_ds[:,[end],:,:] .+= conv_extra2(@view(arr[:,end-nextras[2]+1:end,:,:])) + ## Conv with views is not dispatched to CUDA.conv + # yy = conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2) + # dx[:,[1],:,:] .+= conv(@view(Δ[:,1:nextras[2],:,:]), weight2, pad=pad2, stride=stride2) + # weight2 .= @view(weight2[:,end:-1:1,:,:]) + # dx[:,[end],:,:] .+= conv(@view(Δ[:,end-nextras[2]+1:end,:,:]), weight2, pad=pad2, stride=stride2) end - # Finally fix four corners if needed - # kern = eltype(arr).(kern) - # if typeof(arr) <: CuArray - # kern = gpu(kern) - # end + ## Finally fix four corners if needed n1, n2 = nextras if (n1 > 0) & (n2 > 0) - arr_ds[1,1,:,:] .+= sum(kern[1:n1,1:n2] .* @view(arr[1:n1,1:n2,:,:]), dims=(1,2))[1,1,:,:] - arr_ds[1,end,:,:] .+= sum(kern[1:n1,end-n2+1:end] .* @view(arr[1:n1,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:] - arr_ds[end,end,:,:] .+= sum(kern[end-n1+1:end,end-n2+1:end] .* @view(arr[end-n1+1:end,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:] - arr_ds[end,1,:,:] .+= sum(kern[end-n1+1:end,1:n2] .* 
@view(arr[end-n1+1:end,1:n2,:,:]), dims=(1,2))[1,1,:,:]
+        dx[1,1,:,:] .+= sum(kern[1:n1,1:n2] .* @view(Δ[1:n1,1:n2,:,:]), dims=(1,2))[1,1,:,:]
+        dx[1,end,:,:] .+= sum(kern[1:n1,end-n2+1:end] .* @view(Δ[1:n1,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:]
+        dx[end,end,:,:] .+= sum(kern[end-n1+1:end,end-n2+1:end] .* @view(Δ[end-n1+1:end,end-n2+1:end,:,:]), dims=(1,2))[1,1,:,:]
+        dx[end,1,:,:] .+= sum(kern[end-n1+1:end,1:n2] .* @view(Δ[end-n1+1:end,1:n2,:,:]), dims=(1,2))[1,1,:,:]
     end
 
-    return arr_ds
+    return dx
 end
 
 
 """
-    get_downsamplekernel(n::T) where T<:Integer
+    get_downsamplekernel(n::Int)
 
 # Arguments
-- `n<:Integer`: upsample factor for which a downsample kernel will be determined
+- `n`: upsample factor for which a downsample kernel will be determined
 
 # Outputs
 - `kernel`: downsample kernel
 """
-function get_downsamplekernel(n::T) where T<:Integer
+function get_downsamplekernel(n::Int)
     step = 1//n
     if n % 2 == 0
         start = step//2
@@ -223,3 +221,10 @@ function get_downsamplekernel(n::T) where T<:Integer
     return kernel
 end
 
+function ChainRulesCore.rrule(::typeof(bilinear_upsample), x, k)
+    Ω = bilinear_upsample(x, k)
+    function bilinear_upsample_pullback(Δ)
+        (NO_FIELDS, ∇bilinear_upsample(Δ, k), DoesNotExist())
+    end
+    return Ω, bilinear_upsample_pullback
+end
diff --git a/test/upsample.jl b/test/upsample.jl
index 2b2797a06..eacd20d52 100644
--- a/test/upsample.jl
+++ b/test/upsample.jl
@@ -1,4 +1,5 @@
 using CUDA
+CUDA.allowscalar(false)
 
 @testset "bilinear_upsample 2d" begin
     x = reshape(Float32[1. 2.; 3. 4.], (2,2,1,1))
@@ -14,7 +15,16 @@ using CUDA
     @test eltype(y) == Float32
     @test y ≈ y_true
 
+    gradtest(x->bilinear_upsample(x, (3, 2)), x, atol=1e-4)
+
+    # CUDA compatibility
     y = bilinear_upsample(x |> cu, (3, 2))
     @test y isa CuArray
     @test Array(y) ≈ y_true
+    g_gpu = Zygote.gradient(x -> sum(sin.(bilinear_upsample(x, (3, 2))))
+                            , x |> cu)[1]
+    @test g_gpu isa CuArray
+    g_cpu = Zygote.gradient(x -> sum(sin.(bilinear_upsample(x, (3, 2))))
+                            , x)[1]
+    @test Array(g_gpu) ≈ g_cpu atol=1e-4
 end

From 1cef99ab3d33b3a81b7b99a3c5236967bded7e7c Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Wed, 30 Dec 2020 23:44:33 +0100
Subject: [PATCH 04/12] cl/upsample

---
 Project.toml     |  5 +++--
 test/runtests.jl |  2 ++
 test/upsample.jl | 24 +++++++++++-------------
 3 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/Project.toml b/Project.toml
index 3136e07c1..7990c9978 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,7 +7,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
-StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
@@ -17,6 +16,7 @@ julia = "1.3"
 
 [extras]
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["ChainRulesTestUtils", "FiniteDifferences", "Random", "StableRNGs", "Test", "Zygote"]
+test = ["ChainRulesTestUtils", "CUDA", "FiniteDifferences",
+        "Random", "StableRNGs", "Test", "Zygote"]
diff --git a/test/runtests.jl b/test/runtests.jl
index 0ef0676fc..16cb3394e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -6,6 +6,8 @@
 using FiniteDifferences: FiniteDifferenceMethod, central_fdm
 import Zygote
 using Zygote: gradient
 using StableRNGs
+using CUDA
+CUDA.allowscalar(false)
 
 const rng = StableRNG(123)
 
diff --git a/test/upsample.jl b/test/upsample.jl
index eacd20d52..22108b9f5 100644
--- a/test/upsample.jl
+++ b/test/upsample.jl
@@ -1,6 +1,3 @@
-using CUDA
-CUDA.allowscalar(false)
-
 @testset "bilinear_upsample 2d" begin
     x = reshape(Float32[1. 2.; 3. 4.], (2,2,1,1))
     y_true = [1//1 5//4 7//4 2//1;
@@ -17,14 +14,15 @@ CUDA.allowscalar(false)
 
     gradtest(x->bilinear_upsample(x, (3, 2)), x, atol=1e-4)
 
-    # CUDA compatibility
-    y = bilinear_upsample(x |> cu, (3, 2))
-    @test y isa CuArray
-    @test Array(y) ≈ y_true
-    g_gpu = Zygote.gradient(x -> sum(sin.(bilinear_upsample(x, (3, 2))))
-                            , x |> cu)[1]
-    @test g_gpu isa CuArray
-    g_cpu = Zygote.gradient(x -> sum(sin.(bilinear_upsample(x, (3, 2))))
-                            , x)[1]
-    @test Array(g_gpu) ≈ g_cpu atol=1e-4
+    if CUDA.has_cuda()
+        y = bilinear_upsample(x |> cu, (3, 2))
+        @test y isa CuArray
+        @test Array(y) ≈ y_true
+        g_gpu = Zygote.gradient(x -> sum(sin.(bilinear_upsample(x, (3, 2))))
+                                , x |> cu)[1]
+        @test g_gpu isa CuArray
+        g_cpu = Zygote.gradient(x -> sum(sin.(bilinear_upsample(x, (3, 2))))
+                                , x)[1]
+        @test Array(g_gpu) ≈ g_cpu atol=1e-4
+    end
 end

From 38108bab8bd1b2dd659708bea5f7f9a5037ea9a1 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 3 Jan 2021 23:27:44 +0100
Subject: [PATCH 05/12] small fix

---
 src/upsample.jl | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/upsample.jl b/src/upsample.jl
index 15c7c7bc5..3f4d1c125 100644
--- a/src/upsample.jl
+++ b/src/upsample.jl
@@ -27,7 +27,8 @@ function bilinear_upsample(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})
     ihigh2_r = adjoint_of_idx(ilow2)[ihigh2]
 
     @inbounds y = @view(x[ilow1,ilow2,:,:]) .* (1 .- wdiff1) .+ @view(x[ihigh1,ilow2,:,:]) .* wdiff1
-    @inbounds y = y .* (1 .- wdiff2) .+ @view(y[:,ihigh2_r,:,:]) .* wdiff2
+    @inbounds y .= y .* (1 .- wdiff2) .+ y[:,ihigh2_r,:,:] .* wdiff2
+    # @inbounds y = y .* (1 .- wdiff2) .+ @view(y[:,ihigh2_r,:,:]) .* wdiff2 # equivalent to line above
     return y
 end
 
@@ -100,23 +101,22 @@ which have been corrected for manually.
""" function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) # This function is gpu friendly - + + # Be more efficient on some corner cases if size(Δ, 1) == k[1] Δ = sum(Δ, dims=1) k = (1, k[2]) end - if size(Δ, 2) == k[2] Δ = sum(Δ, dims=2) k = (k[1], 1) end - - if (size(Δ, 1) == 1) & (size(Δ, 2) == 1) + if (size(Δ, 1) == 1) && (size(Δ, 2) == 1) dx = Δ return dx end - n_chan, n_batch = size(Δ,3), size(Δ,4) + n_chan, n_batch = size(Δ, 3), size(Δ, 4) kern1 = get_downsamplekernel(k[1]) kern2 = get_downsamplekernel(k[2]) From 1ac7d2c950b042dddbf0bdaef6ef4c1f92588dc5 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 5 Jan 2021 17:18:45 +0100 Subject: [PATCH 06/12] cleanup --- src/NNlib.jl | 1 - src/misc.jl | 25 ------------------- src/upsample.jl | 53 +++++++++++++++++++++++++-------------- test/misc.jl | 63 ---------------------------------------------- test/upsample.jl | 65 ++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 100 insertions(+), 107 deletions(-) delete mode 100644 src/misc.jl delete mode 100644 test/misc.jl diff --git a/src/NNlib.jl b/src/NNlib.jl index 5e4df277a..d5c89dff1 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -25,7 +25,6 @@ end include("activations.jl") include("softmax.jl") -include("misc.jl") include("batched/batchedmul.jl") include("gemm.jl") include("conv.jl") diff --git a/src/misc.jl b/src/misc.jl deleted file mode 100644 index 02c9fa53d..000000000 --- a/src/misc.jl +++ /dev/null @@ -1,25 +0,0 @@ -export pixel_shuffle - -""" - pixel_shuffle(x, r) - -Pixel shuffling operation. `r` is the upscale factor for shuffling. -The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N] -Used extensively in super-resolution networks to upsample -towards high resolution features. 
- -Reference : https://arxiv.org/pdf/1609.05158.pdf -""" -function pixel_shuffle(x::AbstractArray, r::Integer) - @assert ndims(x) > 2 - d = ndims(x) - 2 - sizein = size(x)[1:d] - cin, n = size(x, d+1), size(x, d+2) - @assert cin % r^d == 0 - cout = cin ÷ r^d - # x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866 - x = reshape(x, sizein..., ntuple(i->r, d)..., cout, n) - perm = [d+1:2d 1:d]' |> vec # = [d+1, 1, d+2, 2, ..., 2d, d] - x = permutedims(x, (perm..., 2d+1, 2d+2)) - return reshape(x, ((r .* sizein)..., cout, n)) -end diff --git a/src/upsample.jl b/src/upsample.jl index 3f4d1c125..294f96938 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -1,4 +1,4 @@ -export bilinear_upsample, ∇bilinear_upsample +export bilinear_upsample, ∇bilinear_upsample, pixel_shuffle """ bilinear_upsample(x::AbstractArray{<:Number,4}, k::NTuple{2,Int}) @@ -118,13 +118,10 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) n_chan, n_batch = size(Δ, 3), size(Δ, 4) - kern1 = get_downsamplekernel(k[1]) - kern2 = get_downsamplekernel(k[2]) + kern1 = get_downsamplekernel(Δ, k[1]) + kern2 = get_downsamplekernel(Δ, k[2]) kern = kern1 .* kern2' - # TODO there must be a more convenient way to send to gpu - kern = convert(typeof(Δ), reshape(kern, size(kern)..., 1, 1)) - kern = dropdims(kern, dims=(3,4)) - + pad = (floor(Int, k[1]//2), floor(Int, k[2]//2)) stride = k weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan)) @@ -197,17 +194,9 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) return dx end - -""" - get_downsamplekernel(n::Int) - -# Arguments -- `n`: upsample factor for which a downsample kernel will be determined - -# Outputs -- `kernel`: downsample kernel -""" -function get_downsamplekernel(n::Int) +# `n` upsample factor for which a downsample kernel will be determined. +# Δ is given in case of necessity of gpu conversion +function get_downsamplekernel(Δ, n::Int) step = 1//n if n % 2 == 0 start = step//2 @@ -218,6 +207,9 @@ function get_downsamplekernel(n::Int) upward = collect(start:step:1//1) kernel = [upward; reverse(upward[1:end-1])] end + # TODO there must be a more convenient way to send to gpu + kernel = convert(typeof(Δ), reshape(kernel, length(kernel), 1, 1, 1)) + kernel = dropdims(kernel, dims=(2,3,4)) return kernel end @@ -228,3 +220,28 @@ function ChainRulesCore.rrule(::typeof(bilinear_upsample), x, k) end return Ω, bilinear_upsample_pullback end + + +""" + pixel_shuffle(x, r) + +Pixel shuffling operation. `r` is the upscale factor for shuffling. +The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N] +Used extensively in super-resolution networks to upsample +towards high resolution features. 
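+
+For example (the values below can be checked against the `pixel_shuffle`
+tests moved into `test/upsample.jl` further down in this patch):
+
+```julia
+x = reshape(1:16, (2, 2, 4, 1))  # r^2 input channels with r = 2
+y = pixel_shuffle(x, 2)          # size(y) == (4, 4, 1, 1)
+y[:, :, 1, 1] == [1 9 3 11; 5 13 7 15; 2 10 4 12; 6 14 8 16]  # true
+```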
+ +Reference : https://arxiv.org/pdf/1609.05158.pdf +""" +function pixel_shuffle(x::AbstractArray, r::Integer) + @assert ndims(x) > 2 + d = ndims(x) - 2 + sizein = size(x)[1:d] + cin, n = size(x, d+1), size(x, d+2) + @assert cin % r^d == 0 + cout = cin ÷ r^d + # x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866 + x = reshape(x, sizein..., ntuple(i->r, d)..., cout, n) + perm = [d+1:2d 1:d]' |> vec # = [d+1, 1, d+2, 2, ..., 2d, d] + x = permutedims(x, (perm..., 2d+1, 2d+2)) + return reshape(x, ((r .* sizein)..., cout, n)) +end diff --git a/test/misc.jl b/test/misc.jl deleted file mode 100644 index 2f2d724da..000000000 --- a/test/misc.jl +++ /dev/null @@ -1,63 +0,0 @@ -@testset "pixel_shuffle" begin - x = reshape(1:16, (2, 2, 4, 1)) - # [:, :, 1, 1] = - # 1 3 - # 2 4 - # [:, :, 2, 1] = - # 5 7 - # 6 8 - # [:, :, 3, 1] = - # 9 11 - # 10 12 - # [:, :, 4, 1] = - # 13 15 - # 14 16 - - y_true = [1 9 3 11 - 5 13 7 15 - 2 10 4 12 - 6 14 8 16][:,:,:,:] - - y = pixel_shuffle(x, 2) - @test size(y) == size(y_true) - @test y_true == y - - x = reshape(1:32, (2, 2, 8, 1)) - y_true = zeros(Int, 4, 4, 2, 1) - y_true[:,:,1,1] .= [ 1 9 3 11 - 5 13 7 15 - 2 10 4 12 - 6 14 8 16 ] - - y_true[:,:,2,1] .= [ 17 25 19 27 - 21 29 23 31 - 18 26 20 28 - 22 30 24 32] - - y = pixel_shuffle(x, 2) - @test size(y) == size(y_true) - @test y_true == y - - x = reshape(1:4*3*27*2, (4,3,27,2)) - y = pixel_shuffle(x, 3) - @test size(y) == (12, 9, 3, 2) - # batch dimension is preserved - x1 = x[:,:,:,[1]] - x2 = x[:,:,:,[2]] - y1 = pixel_shuffle(x1, 3) - y2 = pixel_shuffle(x2, 3) - @test cat(y1, y2, dims=4) == y - - for d in [1, 2, 3] - r = rand(1:5) - n = rand(1:5) - c = rand(1:5) - insize = rand(1:5, d) - x = rand(insize..., r^d*c, n) - - y = pixel_shuffle(x, r) - @test size(y) == ((r .* insize)..., c, n) - - gradtest(x -> pixel_shuffle(x, r), x) - end -end diff --git a/test/upsample.jl b/test/upsample.jl index 22108b9f5..13387e4ca 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -26,3 +26,68 @@ @test Array(g_cpu) ≈ g_cpu atol=1e-4 end end + +@testset "pixel_shuffle" begin + x = reshape(1:16, (2, 2, 4, 1)) + # [:, :, 1, 1] = + # 1 3 + # 2 4 + # [:, :, 2, 1] = + # 5 7 + # 6 8 + # [:, :, 3, 1] = + # 9 11 + # 10 12 + # [:, :, 4, 1] = + # 13 15 + # 14 16 + + y_true = [1 9 3 11 + 5 13 7 15 + 2 10 4 12 + 6 14 8 16][:,:,:,:] + + y = pixel_shuffle(x, 2) + @test size(y) == size(y_true) + @test y_true == y + + x = reshape(1:32, (2, 2, 8, 1)) + y_true = zeros(Int, 4, 4, 2, 1) + y_true[:,:,1,1] .= [ 1 9 3 11 + 5 13 7 15 + 2 10 4 12 + 6 14 8 16 ] + + y_true[:,:,2,1] .= [ 17 25 19 27 + 21 29 23 31 + 18 26 20 28 + 22 30 24 32] + + y = pixel_shuffle(x, 2) + @test size(y) == size(y_true) + @test y_true == y + + x = reshape(1:4*3*27*2, (4,3,27,2)) + y = pixel_shuffle(x, 3) + @test size(y) == (12, 9, 3, 2) + # batch dimension is preserved + x1 = x[:,:,:,[1]] + x2 = x[:,:,:,[2]] + y1 = pixel_shuffle(x1, 3) + y2 = pixel_shuffle(x2, 3) + @test cat(y1, y2, dims=4) == y + + for d in [1, 2, 3] + r = rand(1:5) + n = rand(1:5) + c = rand(1:5) + insize = rand(1:5, d) + x = rand(insize..., r^d*c, n) + + y = pixel_shuffle(x, r) + @test size(y) == ((r .* insize)..., c, n) + + gradtest(x -> pixel_shuffle(x, r), x) + end +end + From 1437c3d44c005602dc6a2bd1eb970e4e33e68397 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 5 Jan 2021 19:11:19 +0100 Subject: [PATCH 07/12] use cat --- Project.toml | 5 ++--- src/upsample.jl | 7 +------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git 
a/Project.toml b/Project.toml index 7990c9978..80bec9943 100644 --- a/Project.toml +++ b/Project.toml @@ -15,8 +15,8 @@ Requires = "0.5, 1.0" julia = "1.3" [extras] -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -24,5 +24,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ChainRulesTestUtils", "CUDA", "FiniteDifferences", - "Random", "StableRNGs", "Test", "Zygote"] +test = ["ChainRulesTestUtils", "CUDA", "FiniteDifferences", "Random", "StableRNGs", "Test", "Zygote"] diff --git a/src/upsample.jl b/src/upsample.jl index 294f96938..317af5c05 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -124,12 +124,7 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) pad = (floor(Int, k[1]//2), floor(Int, k[2]//2)) stride = k - weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan)) - weight .= 0 - - for i in 1:n_chan - weight[:,:,i,i] .= kern - end + weight = cat(fill(kern, n_chan), dims=(3,4)) dx = conv(Δ, weight, pad=pad, stride=stride) From 82e30c57f0d177c5ce6d9305875fbd7b73bcc4bf Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 7 Jan 2021 08:06:46 +0100 Subject: [PATCH 08/12] remove cat --- src/upsample.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/upsample.jl b/src/upsample.jl index 317af5c05..5aff4f243 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -124,7 +124,13 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) pad = (floor(Int, k[1]//2), floor(Int, k[2]//2)) stride = k - weight = cat(fill(kern, n_chan), dims=(3,4)) + + weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan)) + weight .= 0 + for i in 1:n_chan + weight[:,:,i,i] .= kern + end + # weight = cat(fill(kern, n_chan), dims=(3,4)) # produces Array{Any}, revisit in the future dx = conv(Δ, weight, pad=pad, stride=stride) From 1dba5dba6bc909001039c0246ed2536b6423290e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 7 Jan 2021 08:21:50 +0100 Subject: [PATCH 09/12] fix cat --- src/upsample.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index 5aff4f243..f161090d3 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -120,18 +120,12 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) kern1 = get_downsamplekernel(Δ, k[1]) kern2 = get_downsamplekernel(Δ, k[2]) - kern = kern1 .* kern2' + kern = kern1 * kern2' pad = (floor(Int, k[1]//2), floor(Int, k[2]//2)) stride = k - weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan)) - weight .= 0 - for i in 1:n_chan - weight[:,:,i,i] .= kern - end - # weight = cat(fill(kern, n_chan), dims=(3,4)) # produces Array{Any}, revisit in the future - + weight = cat(fill(kern, n_chan)..., dims=(3,4)) dx = conv(Δ, weight, pad=pad, stride=stride) # Still have to fix edge effects due to zero-padding of convolution, From 39cf0ca715b00762f0e7ef357da8a4f800a2b100 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 7 Jan 2021 08:34:49 +0100 Subject: [PATCH 10/12] more cat --- src/upsample.jl | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index f161090d3..224cb5dab 
100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -138,12 +138,7 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) kern1 = kern[1:nextras[1],:] pad1 = (0, pad[2]) stride1 = (1, stride[2]) - weight1 = similar(Δ, eltype(Δ), (size(kern1)..., n_chan, n_chan)) - weight1 .= 0 - for i in 1:n_chan - weight1[:,:,i,i] .= kern1 - end - + weight1 = cat(fill(kern1, n_chan)..., dims=(3,4)) dx[[1],:,:,:] .+= conv(Δ[1:nextras[1],:,:,:], weight1, pad=pad1, stride=stride1) weight1 .= weight1[end:-1:1,:,:,:] dx[[end],:,:,:] .+= conv(Δ[end-nextras[1]+1:end,:,:,:], weight1, pad=pad1, stride=stride1) @@ -159,12 +154,8 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) kern2 = kern[:,1:nextras[2]] pad2 = (pad[1], 0) stride2 = (stride[1], 1) - weight2 = similar(Δ, eltype(Δ), (size(kern2)..., n_chan, n_chan)) - weight2 .= 0 - for i in 1:n_chan - weight2[:,:,i,i] .= kern2 - end - + weight2 = cat(fill(kern2, n_chan)..., dims=(3,4)) + yy = conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2) dx[:,[1],:,:] .+= conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2) weight2 .= weight2[:,end:-1:1,:,:] From caca2c814c25dba274db2ced5efe091ff0d6e4f6 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 7 Jan 2021 08:35:42 +0100 Subject: [PATCH 11/12] where T --- src/upsample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/upsample.jl b/src/upsample.jl index 224cb5dab..5dbccb15b 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -13,7 +13,7 @@ The interpolation grid is identical to the one used by `imresize` from `Images.j Currently only 2d upsampling is supported. """ -function bilinear_upsample(x::AbstractArray{<:Number,4}, k::NTuple{2,Int}) +function bilinear_upsample(x::AbstractArray{T,4}, k::NTuple{2,Int}) where T # This function is gpu friendly imgsize = size(x) From 147f03dc48574382c09341235483d7e2d558ffbd Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 8 Jan 2021 00:56:05 +0100 Subject: [PATCH 12/12] change name; remove cat --- src/upsample.jl | 43 +++++++++++++++++++++++++++++-------------- test/upsample.jl | 12 ++++++------ 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/src/upsample.jl b/src/upsample.jl index 5dbccb15b..860e8441b 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -1,7 +1,7 @@ -export bilinear_upsample, ∇bilinear_upsample, pixel_shuffle +export upsample_bilinear, ∇upsample_bilinear, pixel_shuffle """ - bilinear_upsample(x::AbstractArray{<:Number,4}, k::NTuple{2,Int}) + upsample_bilinear(x::AbstractArray{<:Number,4}, k::NTuple{2,Int}) Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `k`, using bilinear interpolation. @@ -13,7 +13,7 @@ The interpolation grid is identical to the one used by `imresize` from `Images.j Currently only 2d upsampling is supported. """ -function bilinear_upsample(x::AbstractArray{T,4}, k::NTuple{2,Int}) where T +function upsample_bilinear(x::AbstractArray{T,4}, k::NTuple{2,Int}) where T # This function is gpu friendly imgsize = size(x) @@ -83,7 +83,7 @@ end """ - ∇bilinear_upsample(Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int}) + ∇upsample_bilinear(Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int}) # Arguments - `Δ`: array that has been upsampled using the upsample factors in `k` @@ -93,13 +93,13 @@ end # Explanation -Custom adjoint for [`bilinear_upsample`](@ref). +Custom adjoint for [`upsample_bilinear`](@ref). 
The adjoint of upsampling is a downsampling operation, which in this implementation is performed using `NNlib.conv` in combination with a downsampling kernel based on the upsampling factors. Because of the zero-padding during convolution, the values at the boundary are polluted by edge-effects, which have been corrected for manually. """ -function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) +function ∇upsample_bilinear(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) # This function is gpu friendly # Be more efficient on some corner cases @@ -125,7 +125,12 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) pad = (floor(Int, k[1]//2), floor(Int, k[2]//2)) stride = k - weight = cat(fill(kern, n_chan)..., dims=(3,4)) + weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan)) + weight .= 0 + for i in 1:n_chan + weight[:,:,i,i] .= kern + end + # weight = cat(fill(kern, n_chan)..., dims=(3,4)) # slow dx = conv(Δ, weight, pad=pad, stride=stride) # Still have to fix edge effects due to zero-padding of convolution, @@ -138,7 +143,12 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) kern1 = kern[1:nextras[1],:] pad1 = (0, pad[2]) stride1 = (1, stride[2]) - weight1 = cat(fill(kern1, n_chan)..., dims=(3,4)) + weight1 = similar(Δ, eltype(Δ), (size(kern1)..., n_chan, n_chan)) + weight1 .= 0 + for i in 1:n_chan + weight1[:,:,i,i] .= kern1 + end + # weight1 = cat(fill(kern1, n_chan)..., dims=(3,4)) # slow dx[[1],:,:,:] .+= conv(Δ[1:nextras[1],:,:,:], weight1, pad=pad1, stride=stride1) weight1 .= weight1[end:-1:1,:,:,:] dx[[end],:,:,:] .+= conv(Δ[end-nextras[1]+1:end,:,:,:], weight1, pad=pad1, stride=stride1) @@ -154,7 +164,12 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int}) kern2 = kern[:,1:nextras[2]] pad2 = (pad[1], 0) stride2 = (stride[1], 1) - weight2 = cat(fill(kern2, n_chan)..., dims=(3,4)) + weight2 = similar(Δ, eltype(Δ), (size(kern2)..., n_chan, n_chan)) + weight2 .= 0 + for i in 1:n_chan + weight2[:,:,i,i] .= kern2 + end + # weight2 = cat(fill(kern2, n_chan)..., dims=(3,4)) # slow yy = conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2) dx[:,[1],:,:] .+= conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2) @@ -199,12 +214,12 @@ function get_downsamplekernel(Δ, n::Int) return kernel end -function ChainRulesCore.rrule(::typeof(bilinear_upsample), x, k) - Ω = bilinear_upsample(x, k) - function bilinear_upsample_pullback(Δ) - (NO_FIELDS, ∇bilinear_upsample(Δ, k), DoesNotExist()) +function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, k) + Ω = upsample_bilinear(x, k) + function upsample_bilinear_pullback(Δ) + (NO_FIELDS, ∇upsample_bilinear(Δ, k), DoesNotExist()) end - return Ω, bilinear_upsample_pullback + return Ω, upsample_bilinear_pullback end diff --git a/test/upsample.jl b/test/upsample.jl index 13387e4ca..6d07a1344 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -1,4 +1,4 @@ -@testset "bilinear_upsample 2d" begin +@testset "upsample_bilinear 2d" begin x = reshape(Float32[1. 2.; 3. 
4.], (2,2,1,1))
     y_true = [1//1 5//4 7//4 2//1;
               1//1 5//4 7//4 2//1;
               5//3 23//12 29//12 8//3;
               7//3 31//12 37//12 10//3;
               3//1 13//4 15//4 4//1;
               3//1 13//4 15//4 4//1][:,:,:,:]
 
-    y = bilinear_upsample(x, (3, 2))
+    y = upsample_bilinear(x, (3, 2))
     @test size(y) == size(y_true)
     @test eltype(y) == Float32
     @test y ≈ y_true
 
-    gradtest(x->bilinear_upsample(x, (3, 2)), x, atol=1e-4)
+    gradtest(x->upsample_bilinear(x, (3, 2)), x, atol=1e-4)
 
     if CUDA.has_cuda()
-        y = bilinear_upsample(x |> cu, (3, 2))
+        y = upsample_bilinear(x |> cu, (3, 2))
         @test y isa CuArray
         @test Array(y) ≈ y_true
-        g_gpu = Zygote.gradient(x -> sum(sin.(bilinear_upsample(x, (3, 2))))
+        g_gpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2))))
                                 , x |> cu)[1]
         @test g_gpu isa CuArray
-        g_cpu = Zygote.gradient(x -> sum(sin.(bilinear_upsample(x, (3, 2))))
+        g_cpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2))))
                                 , x)[1]
         @test Array(g_gpu) ≈ g_cpu atol=1e-4
     end
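
For reference, a minimal sketch of how the pieces introduced by this series fit
together once all twelve patches are applied; the values mirror the tests above,
and the snippet is illustrative rather than part of the patches themselves:

```julia
using NNlib

x = reshape(Float32[1 2; 3 4], (2, 2, 1, 1))
y = upsample_bilinear(x, (3, 2))     # size (6, 4, 1, 1), values equal y_true above
dx = ∇upsample_bilinear(ones(Float32, size(y)), (3, 2))  # size (2, 2, 1, 1), the pullback wired up via the rrule

p = reshape(Float32.(1:16), (2, 2, 4, 1))
z = pixel_shuffle(p, 2)              # size (4, 4, 1, 1)
```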