diff --git a/src/upsample.jl b/src/upsample.jl index 860e8441b..5126441a2 100644 --- a/src/upsample.jl +++ b/src/upsample.jl @@ -1,4 +1,68 @@ -export upsample_bilinear, ∇upsample_bilinear, pixel_shuffle +export upsample_nearest, ∇upsample_nearest, + upsample_bilinear, ∇upsample_bilinear, + pixel_shuffle + +""" + upsample_nearest(x::AbstractArray, scale::NTuple{S,Int}) + +Upsamples by integer multiples along the first `S` dimensions. +Subsequent dimensions of `x` are not altered. + +See also [`upsample_bilinear`](@ref), for two dimensions of an `N=4` array. + +# Example +```jldoctest +julia> upsample_nearest([1 2 3; 4 5 6], (2,3)) +4×9 Array{$Int,2}: + 1 1 1 2 2 2 3 3 3 + 1 1 1 2 2 2 3 3 3 + 4 4 4 5 5 5 6 6 6 + 4 4 4 5 5 5 6 6 6 + +julia> upsample_nearest([1 2 3; 4 5 6], (2,)) +4×3 Array{$Int,1}: + 1 2 3 + 1 2 3 + 4 5 6 + 4 5 6 +``` +""" +function upsample_nearest(x::AbstractArray{T,N}, scales::NTuple{S, <:Integer}) where {T,N,S} + S in 1:N || throw(ArgumentError("can't upsample ndims(x)=$N with scale=$scales")) + outsize = ntuple(d -> d<=S ? scales[d] * size(x,d) : size(x,d), N) + out = similar(x, T, outsize) + writesize = ntuple(N+S) do d + d > 2S && return size(x, d-S) + isodd(d) ? scales[cld(d,2)] : size(x, cld(d,2)) + end + readsize = ntuple(N+S) do d + d > 2S && return size(x, d-S) + isodd(d) ? 1 : size(x, cld(d,2)) + end + reshape(out, writesize) .= reshape(x, readsize) + out +end + +function ∇upsample_nearest(x::AbstractArray{T,N}, scales::NTuple{S, <:Integer}) where {T,N,S} + outsize = ntuple(N) do d + d > S && return size(x,d) + rem(size(x,d), scales[d]) == 0 || throw(ArgumentError("expected input array evenly divisible by scale=$scales, got size(x)=$(size(x))")) + div(size(x,d), scales[d]) + end + tempsize = ntuple(N+S) do d + d > 2S && return size(x, d-S) + s = scales[cld(d,2)] + isodd(d) ? s : div(size(x, cld(d,2)),s) + end + mid = sum(reshape(x, tempsize), dims=ntuple(d -> 2d-1, S)) + reshape(mid, outsize) +end + +function ChainRulesCore.rrule(::typeof(upsample_nearest), x::AbstractArray, s::Tuple) + Ω = upsample_nearest(x, s) + upsample_nearest_pullback(Δ) = (NO_FIELDS, ∇upsample_nearest(Δ, s), DoesNotExist()) + return Ω, upsample_nearest_pullback +end """ upsample_bilinear(x::AbstractArray{<:Number,4}, k::NTuple{2,Int}) diff --git a/test/upsample.jl b/test/upsample.jl index 6d07a1344..08886b4d0 100644 --- a/test/upsample.jl +++ b/test/upsample.jl @@ -1,3 +1,17 @@ +@testset "upsample_nearest, integer scale via reshape" begin + x = reshape(Float32[1. 2.; 3. 4.], (2,2,1,1)) + @test upsample_nearest(x, (3,3))[1,:] == [1,1,1, 2,2,2] + + y = upsample_nearest(x, (2,3)) + @test size(y) == (4,6,1,1) + ∇upsample_nearest(y, (2,3)) == [6 12; 18 24] + + gradtest(x -> upsample_nearest(x, (2,3)), rand(2,2,1,1), check_rrule=false) + + @test_throws ArgumentError ∇upsample_nearest(y, (2,4)) + @test_throws ArgumentError upsample_nearest(x, (1,2,3,4,5)) +end + @testset "upsample_bilinear 2d" begin x = reshape(Float32[1. 2.; 3. 4.], (2,2,1,1)) y_true = [1//1 5//4 7//4 2//1; @@ -90,4 +104,3 @@ end gradtest(x -> pixel_shuffle(x, r), x) end end -