Skip to content

Commit 147f03d

Browse files
change name; remove cat
1 parent caca2c8 commit 147f03d

File tree

2 files changed

+35
-20
lines changed

2 files changed

+35
-20
lines changed

src/upsample.jl

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
export bilinear_upsample, ∇bilinear_upsample, pixel_shuffle
1+
export upsample_bilinear, ∇upsample_bilinear, pixel_shuffle
22

33
"""
4-
bilinear_upsample(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})
4+
upsample_bilinear(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})
55
66
Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `k`,
77
using bilinear interpolation.
@@ -13,7 +13,7 @@ The interpolation grid is identical to the one used by `imresize` from `Images.j
1313
1414
Currently only 2d upsampling is supported.
1515
"""
16-
function bilinear_upsample(x::AbstractArray{T,4}, k::NTuple{2,Int}) where T
16+
function upsample_bilinear(x::AbstractArray{T,4}, k::NTuple{2,Int}) where T
1717
# This function is gpu friendly
1818

1919
imgsize = size(x)
@@ -83,7 +83,7 @@ end
8383

8484

8585
"""
86-
∇bilinear_upsample(Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int})
86+
∇upsample_bilinear(Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int})
8787
8888
# Arguments
8989
- `Δ`: array that has been upsampled using the upsample factors in `k`
@@ -93,13 +93,13 @@ end
9393
9494
# Explanation
9595
96-
Custom adjoint for [`bilinear_upsample`](@ref).
96+
Custom adjoint for [`upsample_bilinear`](@ref).
9797
The adjoint of upsampling is a downsampling operation, which
9898
in this implementation is performed using `NNlib.conv` in combination with a downsampling kernel based on the
9999
upsampling factors. Because of the zero-padding during convolution, the values at the boundary are polluted by edge-effects,
100100
which have been corrected for manually.
101101
"""
102-
function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
102+
function ∇upsample_bilinear(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
103103
# This function is gpu friendly
104104

105105
# Be more efficient on some corner cases
@@ -125,7 +125,12 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
125125
pad = (floor(Int, k[1]//2), floor(Int, k[2]//2))
126126
stride = k
127127

128-
weight = cat(fill(kern, n_chan)..., dims=(3,4))
128+
weight = similar(Δ, eltype(Δ), (size(kern)..., n_chan, n_chan))
129+
weight .= 0
130+
for i in 1:n_chan
131+
weight[:,:,i,i] .= kern
132+
end
133+
# weight = cat(fill(kern, n_chan)..., dims=(3,4)) # slow
129134
dx = conv(Δ, weight, pad=pad, stride=stride)
130135

131136
# Still have to fix edge effects due to zero-padding of convolution,
@@ -138,7 +143,12 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
138143
kern1 = kern[1:nextras[1],:]
139144
pad1 = (0, pad[2])
140145
stride1 = (1, stride[2])
141-
weight1 = cat(fill(kern1, n_chan)..., dims=(3,4))
146+
weight1 = similar(Δ, eltype(Δ), (size(kern1)..., n_chan, n_chan))
147+
weight1 .= 0
148+
for i in 1:n_chan
149+
weight1[:,:,i,i] .= kern1
150+
end
151+
# weight1 = cat(fill(kern1, n_chan)..., dims=(3,4)) # slow
142152
dx[[1],:,:,:] .+= conv(Δ[1:nextras[1],:,:,:], weight1, pad=pad1, stride=stride1)
143153
weight1 .= weight1[end:-1:1,:,:,:]
144154
dx[[end],:,:,:] .+= conv(Δ[end-nextras[1]+1:end,:,:,:], weight1, pad=pad1, stride=stride1)
@@ -154,7 +164,12 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
154164
kern2 = kern[:,1:nextras[2]]
155165
pad2 = (pad[1], 0)
156166
stride2 = (stride[1], 1)
157-
weight2 = cat(fill(kern2, n_chan)..., dims=(3,4))
167+
weight2 = similar(Δ, eltype(Δ), (size(kern2)..., n_chan, n_chan))
168+
weight2 .= 0
169+
for i in 1:n_chan
170+
weight2[:,:,i,i] .= kern2
171+
end
172+
# weight2 = cat(fill(kern2, n_chan)..., dims=(3,4)) # slow
158173

159174
yy = conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2)
160175
dx[:,[1],:,:] .+= conv(Δ[:,1:nextras[2],:,:], weight2, pad=pad2, stride=stride2)
@@ -199,12 +214,12 @@ function get_downsamplekernel(Δ, n::Int)
199214
return kernel
200215
end
201216

202-
function ChainRulesCore.rrule(::typeof(bilinear_upsample), x, k)
203-
Ω = bilinear_upsample(x, k)
204-
function bilinear_upsample_pullback(Δ)
205-
(NO_FIELDS, bilinear_upsample(Δ, k), DoesNotExist())
217+
function ChainRulesCore.rrule(::typeof(upsample_bilinear), x, k)
218+
Ω = upsample_bilinear(x, k)
219+
function upsample_bilinear_pullback(Δ)
220+
(NO_FIELDS, upsample_bilinear(Δ, k), DoesNotExist())
206221
end
207-
return Ω, bilinear_upsample_pullback
222+
return Ω, upsample_bilinear_pullback
208223
end
209224

210225

test/upsample.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "bilinear_upsample 2d" begin
1+
@testset "upsample_bilinear 2d" begin
22
x = reshape(Float32[1. 2.; 3. 4.], (2,2,1,1))
33
y_true = [1//1 5//4 7//4 2//1;
44
1//1 5//4 7//4 2//1;
@@ -7,21 +7,21 @@
77
3//1 13//4 15//4 4//1;
88
3//1 13//4 15//4 4//1][:,:,:,:]
99

10-
y = bilinear_upsample(x, (3, 2))
10+
y = upsample_bilinear(x, (3, 2))
1111
@test size(y) == size(y_true)
1212
@test eltype(y) == Float32
1313
@test y ≈ y_true
1414

15-
gradtest(x->bilinear_upsample(x, (3, 2)), x, atol=1e-4)
15+
gradtest(x->upsample_bilinear(x, (3, 2)), x, atol=1e-4)
1616

1717
if CUDA.has_cuda()
18-
y = bilinear_upsample(x |> cu, (3, 2))
18+
y = upsample_bilinear(x |> cu, (3, 2))
1919
@test y isa CuArray
2020
@test Array(y) ≈ y_true
21-
g_gpu = Zygote.gradient(x -> sum(sin.(bilinear_upsample(x, (3, 2))))
21+
g_gpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2))))
2222
, x |> cu)[1]
2323
@test g_gpu isa CuArray
24-
g_cpu = Zygote.gradient(x -> sum(sin.(bilinear_upsample(x, (3, 2))))
24+
g_cpu = Zygote.gradient(x -> sum(sin.(upsample_bilinear(x, (3, 2))))
2525
, x)[1]
2626
@test Array(g_gpu) ≈ g_cpu atol=1e-4
2727
end

0 commit comments

Comments
 (0)