@@ -29,3 +29,161 @@ NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number,
29
29
30
30
# A batched adjoint/transpose wrapper shares storage with its parent array, so
# converting it to a device pointer is just converting the parent's data.
function Base.unsafe_convert(::Type{CuPtr{T}}, A::NNlib.BatchedAdjOrTrans{T}) where {T}
    return Base.unsafe_convert(CuPtr{T}, parent(A))
end
32
#
# Upsampling
#

# GPU based bilinear upsampling including its gradient
#
# The code is a translation from the following files:
# - https://github.com/pytorch/pytorch/blob/v1.8.0-rc1/caffe2/operators/upsample_op.cu
# - https://github.com/pytorch/pytorch/blob/v1.8.0-rc1/caffe2/core/common_gpu.h
#
# Copyright (c) 2016-2021 Facebook Inc.
# Copyright (c) 2015 Google Inc.
# Copyright (c) 2015 Yangqing Jia
# Copyright 2019-2020 Kakao Brain
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are
# permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of
#    conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of
#    conditions and the following disclaimer in the documentation and/or other materials
#    provided with the distribution.
#
# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America and
#    IDIAP Research Institute nor the names of its contributors may be used to endorse or
#    promote products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Forward and backward pass have been tested to produce the same output
# as pytorch with align_corners=True - it works modulo bit noise.
79
# Bilinear upsampling kernel for WHCN (width, height, channel, batch) arrays.
#
# Launch layout: 1-D grid, one thread per output pixel position, flattened over
# out_w * out_h (`n_elem`); each thread writes its (ow, oh) pixel for every
# channel and batch entry. `rheight`/`rwidth` are the align_corners=true
# source/destination scale ratios precomputed on the host.
function upsample_bilinear_whcn_kernel!(n_elem, rheight, rwidth, x, y)
    # 0-based flat index of this thread's output pixel
    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x

    if index < n_elem
        in_w, in_h, channels, batchsize = size(x)
        out_w, out_h, _, _ = size(y)

        # 0-based output coordinates recovered from the flat index
        ow = index % out_w
        oh = index ÷ out_w

        # Fractional source row: split into the two neighbouring input rows
        # ih0/ih1 and their linear interpolation weights. The offset clamps
        # ih1 to the last row at the bottom edge.
        real_index = rheight * oh
        ih0 = Base.floor(Int, real_index)
        offset = (ih0 < in_h - 1) ? 1 : 0
        ih1 = ih0 + offset + 1
        h1lambda = real_index - ih0
        h0lambda = 1 - h1lambda
        ih0 += 1  # switch to Julia's 1-based indexing

        # Same split along the width axis.
        real_index = rwidth * ow
        iw0 = Base.floor(Int, real_index)
        offset = (iw0 < in_w - 1) ? 1 : 0
        iw1 = iw0 + offset + 1
        w1lambda = real_index - iw0
        w0lambda = 1 - w1lambda
        iw0 += 1

        # Blend the four neighbouring input pixels for every (c, n) slice.
        @inbounds for n in 1:batchsize, c in 1:channels
            val = h0lambda * (w0lambda * x[iw0, ih0, c, n] +   # h0 * w0 * i00
                              w1lambda * x[iw1, ih0, c, n]) +  # h0 * w1 * i01
                  h1lambda * (w0lambda * x[iw0, ih1, c, n] +   # h1 * w0 * i10
                              w1lambda * x[iw1, ih1, c, n])    # h1 * w1 * i11
            y[ow + 1, oh + 1, c, n] = val
        end
    end
    return nothing
end
117
+
118
# Gradient of bilinear upsampling for WHCN arrays.
#
# Δ is the gradient backpropagated from downstream layers (sized like the
# forward output); `dx` accumulates the input gradient. Launch layout: 1-D
# grid, one thread per Δ pixel, flattened over in_width * in_height
# (`n_elem`). Each Δ pixel scatters into up to four dx pixels, and distinct
# Δ pixels can hit the same dx location, so the writes must be atomic.
function ∇upsample_bilinear_whcn_kernel!(n_elem, rheight, rwidth, Δ, dx)
    # 0-based flat index of this thread's Δ pixel
    index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x

    if index < n_elem
        in_width, in_height, channels, batchsize = size(Δ)
        out_width, out_height, _, _ = size(dx)

        # 0-based coordinates within Δ
        iw = index % in_width
        ih = index ÷ in_width

        # Destination rows in dx and their interpolation weights (Y axis);
        # the offset clamps oh1 to the last row at the bottom edge.
        real_index_h = rheight * ih
        oh0 = Base.floor(Int, real_index_h)
        offset = (oh0 < out_height - 1) ? 1 : 0
        oh1 = oh0 + offset + 1
        h1lambda = real_index_h - oh0
        h0lambda = 1 - h1lambda
        oh0 += 1  # switch to Julia's 1-based indexing

        # Destination columns and weights (X axis).
        real_index_w = rwidth * iw
        ow0 = Base.floor(Int, real_index_w)
        offset = (ow0 < out_width - 1) ? 1 : 0
        ow1 = ow0 + offset + 1
        w1lambda = real_index_w - ow0
        w0lambda = 1 - w1lambda
        ow0 += 1

        # Scatter each Δ pixel into the four source pixels it interpolated
        # from, weighted exactly as in the forward pass.
        @inbounds for n in 1:batchsize, c in 1:channels
            val = Δ[iw + 1, ih + 1, c, n]
            @atomic dx[ow0, oh0, c, n] += h0lambda * w0lambda * val
            @atomic dx[ow1, oh0, c, n] += h0lambda * w1lambda * val
            @atomic dx[ow0, oh1, c, n] += h1lambda * w0lambda * val
            @atomic dx[ow1, oh1, c, n] += h1lambda * w1lambda * val
        end
    end
    return nothing
end
159
+
160
# Host wrapper: bilinear-upsample `x` into the preallocated `y` (both WHCN).
#
# Scale ratios follow the align_corners=true convention, (in-1)/(out-1).
# Fixes two edge cases of the original: an output extent of 1 no longer
# divides by zero (the ratio is 0, so every output pixel samples index 0,
# matching PyTorch's area_pixel_compute_scale), and an empty output no
# longer attempts a zero-thread kernel launch.
function NNlib.upsample_bilinear_whcn!(y::CuArray{T,4}, x::CuArray{T,4}) where T
    w, h, c, n = size(x)
    out_w, out_h = size(y, 1), size(y, 2)

    out_size = out_h * out_w
    out_size == 0 && return y  # nothing to compute; avoid threads = 0 launch

    # align_corners=true ratios; 0 when the output extent is 1
    rheight = out_h > 1 ? T((h - 1) / (out_h - 1)) : zero(T)
    rwidth  = out_w > 1 ? T((w - 1) / (out_w - 1)) : zero(T)

    # Compile first, then size the launch from the occupancy API: one thread
    # per output pixel, capped at 256 threads per block.
    kernel = @cuda launch=false upsample_bilinear_whcn_kernel!(out_size, rheight, rwidth, x, y)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = Base.min(out_size, config.threads)
    blocks = cld(out_size, threads)
    kernel(out_size, rheight, rwidth, x, y; threads=threads, blocks=blocks)
    return y
end
175
+
176
# Host wrapper for the gradient pass: scatter Δ (sized like the forward
# output) into the preallocated input-gradient buffer `dx` (both WHCN).
#
# Ratios are reversed relative to the forward pass: (out-1)/(in-1), where
# "in" here is Δ's extent. Fixes two edge cases of the original: a Δ extent
# of 1 no longer divides by zero (ratio is 0, mirroring the forward-pass
# align_corners convention), and an empty Δ no longer attempts a
# zero-thread kernel launch.
function NNlib.∇upsample_bilinear_whcn!(dx::CuArray{T,4}, Δ::CuArray{T,4}) where T
    w, h, c, n = Base.size(Δ)
    out_w, out_h = size(dx, 1), size(dx, 2)

    in_size = h * w
    in_size == 0 && return dx  # nothing to scatter; avoid threads = 0 launch

    # reversed compared to forward pass; 0 when Δ's extent is 1
    rheight = h > 1 ? T((out_h - 1) / (h - 1)) : zero(T)
    rwidth  = w > 1 ? T((out_w - 1) / (w - 1)) : zero(T)

    # One thread per Δ pixel, capped at 256 threads per block.
    kernel = @cuda launch=false ∇upsample_bilinear_whcn_kernel!(in_size, rheight, rwidth, Δ, dx)
    config = launch_configuration(kernel.fun; max_threads=256)
    threads = Base.min(in_size, config.threads)
    blocks = cld(in_size, threads)
    kernel(in_size, rheight, rwidth, Δ, dx; threads=threads, blocks=blocks)
    return dx
end
0 commit comments