Skip to content

Upsample docstring improvements #278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 18, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 88 additions & 15 deletions src/upsample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,34 @@ can be directly specified with a keyword argument.
The size of the output is equal to
`(scale[1]*S1, scale[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.

Examples:
```julia
upsample_bilinear(x, (2, pi)) # real scaling factors are allowed
upsample_bilinear(x; size=(64,64)) # specify ouput size
# Examples

```jldoctest
julia> x = reshape(Float32[1 2 3; 4 5 6], (2,3,1,1))
2×3×1×1 Array{Float32,4}:
[:, :, 1, 1] =
1.0 2.0 3.0
4.0 5.0 6.0

julia> upsample_bilinear(x, (2, 3))
4×9×1×1 Array{Float32,4}:
[:, :, 1, 1] =
1.0 1.25 1.5 1.75 2.0 2.25 2.5 2.75 3.0
2.0 2.25 2.5 2.75 3.0 3.25 3.5 3.75 4.0
3.0 3.25 3.5 3.75 4.0 4.25 4.5 4.75 5.0
4.0 4.25 4.5 4.75 5.0 5.25 5.5 5.75 6.0

julia> ans == upsample_bilinear(x; size=(4, 9)) # specify ouput size instead
true

julia> upsample_bilinear(x, (2.5, 3.5)) # non-integer scaling factors are allowed
5×10×1×1 Array{Float32,4}:
[:, :, 1, 1] =
1.0 1.22222 1.44444 1.66667 1.88889 2.11111 2.33333 2.55556 2.77778 3.0
1.75 1.97222 2.19444 2.41667 2.63889 2.86111 3.08333 3.30556 3.52778 3.75
2.5 2.72222 2.94444 3.16667 3.38889 3.61111 3.83333 4.05556 4.27778 4.5
3.25 3.47222 3.69444 3.91667 4.13889 4.36111 4.58333 4.80556 5.02778 5.25
4.0 4.22222 4.44444 4.66667 4.88889 5.11111 5.33333 5.55556 5.77778 6.0
```
"""
function upsample_bilinear(x::AbstractArray{<:Any,4}, scale::NTuple{2,Real})
Expand Down Expand Up @@ -234,25 +258,74 @@ function ChainRulesCore.rrule(::typeof(upsample_bilinear), x; size)
end

"""
pixel_shuffle(x, r)
pixel_shuffle(x, r::Integer)

Pixel shuffling operation, upscaling by a factor `r`.

Pixel shuffling operation. `r` is the upscale factor for shuffling.
The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N]
Used extensively in super-resolution networks to upsample
towards high resolution features.
For 4-arrays representing `N` images, the operation converts input `size(x) == (W, H, r^2*C, N)`
to output of size `(r*W, r*H, C, N)`. For `D`-dimensional data, it expects `ndims(x) == D+2`
with channel and batch dimensions, and divides the number of channels by `r^D`.

Reference : https://arxiv.org/pdf/1609.05158.pdf
Used in super-resolution networks to upsample towards high resolution features.
Reference: Shi et. al., "Real-Time Single Image and Video Super-Resolution ...", CVPR 2016, https://arxiv.org/abs/1609.05158

# Examples

```jldoctest
julia> x = [10i + j + channel/10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
2×3×4×1 Array{Float64,4}:
[:, :, 1, 1] =
11.1 12.1 13.1
21.1 22.1 23.1

[:, :, 2, 1] =
11.2 12.2 13.2
21.2 22.2 23.2

[:, :, 3, 1] =
11.3 12.3 13.3
21.3 22.3 23.3

[:, :, 4, 1] =
11.4 12.4 13.4
21.4 22.4 23.4

julia> pixel_shuffle(x, 2) # 4 channels used up as 2x upscaling of image dimensions
4×6×1×1 Array{Float64,4}:
[:, :, 1, 1] =
11.1 11.3 12.1 12.3 13.1 13.3
11.2 11.4 12.2 12.4 13.2 13.4
21.1 21.3 22.1 22.3 23.1 23.3
21.2 21.4 22.2 22.4 23.2 23.4

julia> y = [i + channel/10 for i in 1:3, channel in 1:6, batch in 1:1]
3×6×1 Array{Float64, 3}:
[:, :, 1] =
1.1 1.2 1.3 1.4 1.5 1.6
2.1 2.2 2.3 2.4 2.5 2.6
3.1 3.2 3.3 3.4 3.5 3.6

julia> pixel_shuffle(y, 2) # 1D image, with 6 channels reduced to 3
6×3×1 Array{Float64,3}:
[:, :, 1] =
1.1 1.3 1.5
1.2 1.4 1.6
2.1 2.3 2.5
2.2 2.4 2.6
3.1 3.3 3.5
3.2 3.4 3.6
```
"""
function pixel_shuffle(x::AbstractArray, r::Integer)
@assert ndims(x) > 2
ndims(x) > 2 || throw(ArgumentError("expected x with at least 3 dimensions"))
d = ndims(x) - 2
sizein = size(x)[1:d]
cin, n = size(x, d+1), size(x, d+2)
@assert cin % r^d == 0
cin % r^d == 0 || throw(ArgumentError("expected channel dimension to be divisible by r^d = $(
r^d), where d=$d is the number of spatial dimensions. Given r=$r, input size(x) = $(size(x))"))
cout = cin ÷ r^d
# x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866
x = reshape(x, sizein..., ntuple(i->r, d)..., cout, n)
perm = [d+1:2d 1:d]' |> vec # = [d+1, 1, d+2, 2, ..., 2d, d]
perm = hcat(d+1:2d, 1:d) |> transpose |> vec # = [d+1, 1, d+2, 2, ..., 2d, d]
x = permutedims(x, (perm..., 2d+1, 2d+2))
return reshape(x, ((r .* sizein)..., cout, n))
return reshape(x, map(s -> s*r, sizein)..., cout, n)
end