1
- export bilinear_upsample , ∇bilinear_upsample , pixel_shuffle
1
+ export upsample_bilinear , ∇upsample_bilinear , pixel_shuffle
2
2
3
3
"""
4
- bilinear_upsample (x::AbstractArray{<:Number,4}, k::NTuple{2,Int})
4
+ upsample_bilinear (x::AbstractArray{<:Number,4}, k::NTuple{2,Int})
5
5
6
6
Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `k`,
7
7
using bilinear interpolation.
@@ -13,7 +13,7 @@ The interpolation grid is identical to the one used by `imresize` from `Images.j
13
13
14
14
Currently only 2d upsampling is supported.
15
15
"""
16
- function bilinear_upsample (x:: AbstractArray{T,4} , k:: NTuple{2,Int} ) where T
16
+ function upsample_bilinear (x:: AbstractArray{T,4} , k:: NTuple{2,Int} ) where T
17
17
# This function is gpu friendly
18
18
19
19
imgsize = size (x)
83
83
84
84
85
85
"""
86
- ∇bilinear_upsample (Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int})
86
+ ∇upsample_bilinear (Δ::AbstractArray{<:Number,4}, k::NTuple{2,Int})
87
87
88
88
# Arguments
89
89
- `Δ`: array that has been upsampled using the upsample factors in `k`
93
93
94
94
# Explanation
95
95
96
- Custom adjoint for [`bilinear_upsample `](@ref).
96
+ Custom adjoint for [`upsample_bilinear `](@ref).
97
97
The adjoint of upsampling is a downsampling operation, which
98
98
in this implementation is performed using `NNlib.conv` in combination with a downsampling kernel based on the
99
99
upsampling factors. Because of the zero-padding during convolution, the values at the boundary are polluted by edge-effects,
100
100
which have been corrected for manually.
101
101
"""
102
- function ∇bilinear_upsample (Δ:: AbstractArray{<:Number, 4} , k:: NTuple{2,Int} )
102
+ function ∇upsample_bilinear (Δ:: AbstractArray{<:Number, 4} , k:: NTuple{2,Int} )
103
103
# This function is gpu friendly
104
104
105
105
# Be more efficient on some corner cases
@@ -125,7 +125,12 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
125
125
pad = (floor (Int, k[1 ]// 2 ), floor (Int, k[2 ]// 2 ))
126
126
stride = k
127
127
128
- weight = cat (fill (kern, n_chan)... , dims= (3 ,4 ))
128
+ weight = similar (Δ, eltype (Δ), (size (kern)... , n_chan, n_chan))
129
+ weight .= 0
130
+ for i in 1 : n_chan
131
+ weight[:,:,i,i] .= kern
132
+ end
133
+ # weight = cat(fill(kern, n_chan)..., dims=(3,4)) # slow
129
134
dx = conv (Δ, weight, pad= pad, stride= stride)
130
135
131
136
# Still have to fix edge effects due to zero-padding of convolution,
@@ -138,7 +143,12 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
138
143
kern1 = kern[1 : nextras[1 ],:]
139
144
pad1 = (0 , pad[2 ])
140
145
stride1 = (1 , stride[2 ])
141
- weight1 = cat (fill (kern1, n_chan)... , dims= (3 ,4 ))
146
+ weight1 = similar (Δ, eltype (Δ), (size (kern1)... , n_chan, n_chan))
147
+ weight1 .= 0
148
+ for i in 1 : n_chan
149
+ weight1[:,:,i,i] .= kern1
150
+ end
151
+ # weight1 = cat(fill(kern1, n_chan)..., dims=(3,4)) # slow
142
152
dx[[1 ],:,:,:] .+ = conv (Δ[1 : nextras[1 ],:,:,:], weight1, pad= pad1, stride= stride1)
143
153
weight1 .= weight1[end : - 1 : 1 ,:,:,:]
144
154
dx[[end ],:,:,:] .+ = conv (Δ[end - nextras[1 ]+ 1 : end ,:,:,:], weight1, pad= pad1, stride= stride1)
@@ -154,7 +164,12 @@ function ∇bilinear_upsample(Δ::AbstractArray{<:Number, 4}, k::NTuple{2,Int})
154
164
kern2 = kern[:,1 : nextras[2 ]]
155
165
pad2 = (pad[1 ], 0 )
156
166
stride2 = (stride[1 ], 1 )
157
- weight2 = cat (fill (kern2, n_chan)... , dims= (3 ,4 ))
167
+ weight2 = similar (Δ, eltype (Δ), (size (kern2)... , n_chan, n_chan))
168
+ weight2 .= 0
169
+ for i in 1 : n_chan
170
+ weight2[:,:,i,i] .= kern2
171
+ end
172
+ # weight2 = cat(fill(kern2, n_chan)..., dims=(3,4)) # slow
158
173
159
174
yy = conv (Δ[:,1 : nextras[2 ],:,:], weight2, pad= pad2, stride= stride2)
160
175
dx[:,[1 ],:,:] .+ = conv (Δ[:,1 : nextras[2 ],:,:], weight2, pad= pad2, stride= stride2)
@@ -199,12 +214,12 @@ function get_downsamplekernel(Δ, n::Int)
199
214
return kernel
200
215
end
201
216
202
- function ChainRulesCore. rrule (:: typeof (bilinear_upsample ), x, k)
203
- Ω = bilinear_upsample (x, k)
204
- function bilinear_upsample_pullback (Δ)
205
- (NO_FIELDS, ∇bilinear_upsample (Δ, k), DoesNotExist ())
217
+ function ChainRulesCore. rrule (:: typeof (upsample_bilinear ), x, k)
218
+ Ω = upsample_bilinear (x, k)
219
+ function upsample_bilinear_pullback (Δ)
220
+ (NO_FIELDS, ∇upsample_bilinear (Δ, k), DoesNotExist ())
206
221
end
207
- return Ω, bilinear_upsample_pullback
222
+ return Ω, upsample_bilinear_pullback
208
223
end
209
224
210
225
0 commit comments