1
+ using NeuralNetworkNET . APIs . Structs ;
2
+ using System ;
3
+ using System . Threading . Tasks ;
4
+ using NeuralNetworkNET . Extensions ;
5
+
6
+ namespace NeuralNetworkNET . cpuDNN
7
+ {
8
+ /// <summary>
9
+ /// A static class with a collection of pooling extension methods
10
+ /// </summary>
11
+ public static partial class CpuDnn
12
+ {
13
+ /// <summary>
14
+ /// Executes the forward pass on a max pooling layer with a 2x2 window and a stride of 2
15
+ /// </summary>
16
+ /// <param name="x">The input <see cref="Tensor"/> to pool</param>
17
+ /// <param name="xInfo">The info on the input <see cref="Tensor"/></param>
18
+ /// <param name="y">The resulting pooled <see cref="Tensor"/></param>
19
+ public static unsafe void PoolingForward ( in Tensor x , in TensorInfo xInfo , in Tensor y )
20
+ {
21
+ int h = x . Entities , w = x . Length ;
22
+ if ( h < 1 || w < 1 ) throw new ArgumentException ( "The input matrix isn't valid" ) ;
23
+ int
24
+ depth = xInfo . Channels ,
25
+ imgSize = w % depth == 0 ? w / depth : throw new ArgumentException ( nameof ( x ) , "Invalid depth parameter for the input matrix" ) ,
26
+ imgAxis = imgSize . IntegerSquare ( ) ; // Size of an edge of one of the inner images per sample
27
+ if ( imgAxis * imgAxis != imgSize ) throw new ArgumentOutOfRangeException ( nameof ( x ) , "The size of the input matrix isn't valid" ) ;
28
+ int
29
+ poolAxis = imgAxis / 2 + ( imgAxis % 2 == 0 ? 0 : 1 ) ,
30
+ poolSize = poolAxis * poolAxis ,
31
+ poolFinalWidth = depth * poolSize ,
32
+ edge = imgAxis - 1 ;
33
+ if ( ! y . MatchShape ( h , poolFinalWidth ) ) throw new ArgumentException ( "The output tensor shape isn't valid" , nameof ( y ) ) ;
34
+
35
+ // Pooling kernel
36
+ float * px = x , py = y ;
37
+ void Kernel ( int sample )
38
+ {
39
+ int
40
+ sourceBaseOffset = sample * w ,
41
+ resultBaseOffset = sample * poolFinalWidth ;
42
+ for ( int z = 0 ; z < depth ; z ++ )
43
+ {
44
+ int
45
+ sourceZOffset = sourceBaseOffset + z * imgSize ,
46
+ resultZOffset = resultBaseOffset + z * poolSize ,
47
+ c = 0 ;
48
+ for ( int i = 0 ; i < imgAxis ; i += 2 )
49
+ {
50
+ int
51
+ sourceIOffset = sourceZOffset + i * imgAxis ,
52
+ resultXOffset = resultZOffset + c * poolAxis ,
53
+ r = 0 ;
54
+ if ( i == edge )
55
+ {
56
+ // Last row
57
+ for ( int j = 0 ; j < imgAxis ; j += 2 )
58
+ {
59
+ float max ;
60
+ if ( j == w - 1 ) max = px [ sourceIOffset + j ] ; // Last column
61
+ else
62
+ {
63
+ float
64
+ left = px [ sourceIOffset + j ] ,
65
+ right = px [ sourceIOffset + j + 1 ] ;
66
+ max = left > right ? left : right ;
67
+ }
68
+ py [ resultXOffset + r ++ ] = max ;
69
+ }
70
+ }
71
+ else
72
+ {
73
+ int sourceI_1Offset = sourceZOffset + ( i + 1 ) * imgAxis ;
74
+ for ( int j = 0 ; j < imgAxis ; j += 2 )
75
+ {
76
+ float max ;
77
+ if ( j == edge )
78
+ {
79
+ // Last column
80
+ float
81
+ up = px [ sourceIOffset + j ] ,
82
+ down = px [ sourceI_1Offset + j ] ;
83
+ max = up > down ? up : down ;
84
+ }
85
+ else
86
+ {
87
+ float
88
+ upLeft = px [ sourceIOffset + j ] ,
89
+ upRight = px [ sourceIOffset + j + 1 ] ,
90
+ downLeft = px [ sourceI_1Offset + j ] ,
91
+ downRight = px [ sourceI_1Offset + j + 1 ] ,
92
+ maxUp = upLeft > upRight ? upLeft : upRight ,
93
+ maxDown = downLeft > downRight ? downLeft : downRight ;
94
+ max = maxUp > maxDown ? maxUp : maxDown ;
95
+ }
96
+ py [ resultXOffset + r ++ ] = max ;
97
+ }
98
+ }
99
+ c ++ ;
100
+ }
101
+ }
102
+ }
103
+ Parallel . For ( 0 , h , Kernel ) . AssertCompleted ( ) ;
104
+ }
105
+
106
+ /// <summary>
107
+ /// Executes the backward pass on a max pooling layer with a 2x2 window and a stride of 2
108
+ /// </summary>
109
+ /// <param name="x">The original input <see cref="Tensor"/> used during the forward pass</param>
110
+ /// <param name="xInfo">The info on the input <see cref="Tensor"/></param>
111
+ /// <param name="dy">The output error for the current layer</param>
112
+ /// <param name="dx">The resulting backpropagated error</param>
113
+ public static unsafe void PoolingBackward ( in Tensor x , in TensorInfo xInfo , in Tensor dy , in Tensor dx )
114
+ {
115
+ // Prepare the result matrix
116
+ if ( ! dx . MatchShape ( x ) ) throw new ArgumentException ( "The result tensor must have the same shape as the input" , nameof ( dx ) ) ;
117
+ int n = x . Entities , l = x . Length ;
118
+ if ( n < 1 || l < 1 ) throw new ArgumentException ( "The input matrix isn't valid" ) ;
119
+ int
120
+ depth = xInfo . Channels ,
121
+ imgSize = l % depth == 0 ? l / depth : throw new ArgumentException ( nameof ( x ) , "Invalid depth parameter for the input matrix" ) ,
122
+ imgAxis = imgSize . IntegerSquare ( ) ; // Size of an edge of one of the inner images per sample
123
+ if ( imgAxis * imgAxis != imgSize ) throw new ArgumentOutOfRangeException ( nameof ( x ) , "The size of the input matrix isn't valid" ) ;
124
+ int
125
+ poolAxis = imgAxis / 2 + ( imgAxis % 2 == 0 ? 0 : 1 ) ,
126
+ poolSize = poolAxis * poolAxis ,
127
+ poolFinalWidth = depth * poolSize ,
128
+ edge = imgAxis - 1 ;
129
+ int
130
+ pn = dy . Entities ,
131
+ pl = dy . Length ;
132
+ if ( pn != n || pl != poolFinalWidth ) throw new ArgumentException ( "Invalid pooled matrix" , nameof ( dy ) ) ;
133
+
134
+ // Pooling kernel
135
+ float * px = x , pdy = dy , pdx = dx ;
136
+ void Kernel ( int sample )
137
+ {
138
+ int
139
+ sourceBaseOffset = sample * l ,
140
+ resultBaseOffset = sample * poolFinalWidth ;
141
+ for ( int z = 0 ; z < depth ; z ++ )
142
+ {
143
+ int
144
+ sourceZOffset = sourceBaseOffset + z * imgSize ,
145
+ resultZOffset = resultBaseOffset + z * poolSize ,
146
+ c = 0 ;
147
+ for ( int i = 0 ; i < imgAxis ; i += 2 )
148
+ {
149
+ int
150
+ sourceIOffset = sourceZOffset + i * imgAxis ,
151
+ resultXOffset = resultZOffset + c * poolAxis ,
152
+ r = 0 ;
153
+ if ( i == edge )
154
+ {
155
+ // Last row
156
+ for ( int j = 0 ; j < imgAxis ; j += 2 )
157
+ {
158
+ if ( j == l - 1 )
159
+ {
160
+ pdx [ sourceIOffset + j ] = pdy [ resultXOffset + r ++ ] ;
161
+ }
162
+ else
163
+ {
164
+ float
165
+ left = px [ sourceIOffset + j ] ,
166
+ right = px [ sourceIOffset + j + 1 ] ;
167
+ if ( left > right )
168
+ {
169
+ pdx [ sourceIOffset + j ] = pdy [ resultXOffset + r ++ ] ;
170
+ pdx [ sourceIOffset + j + 1 ] = 0 ;
171
+ }
172
+ else
173
+ {
174
+ pdx [ sourceIOffset + j + 1 ] = pdy [ resultXOffset + r ++ ] ;
175
+ pdx [ sourceIOffset + j ] = 0 ;
176
+ }
177
+ }
178
+ }
179
+ }
180
+ else
181
+ {
182
+ int sourceI_1Offset = sourceZOffset + ( i + 1 ) * imgAxis ;
183
+ for ( int j = 0 ; j < imgAxis ; j += 2 )
184
+ {
185
+ if ( j == edge )
186
+ {
187
+ // Last column
188
+ float
189
+ up = px [ sourceIOffset + j ] ,
190
+ down = px [ sourceI_1Offset + j ] ;
191
+ if ( up > down )
192
+ {
193
+ pdx [ sourceIOffset + j ] = pdy [ resultXOffset + r ++ ] ;
194
+ pdx [ sourceI_1Offset + j ] = 0 ;
195
+ }
196
+ else
197
+ {
198
+ pdx [ sourceI_1Offset + j ] = pdy [ resultXOffset + r ++ ] ;
199
+ pdx [ sourceIOffset + j ] = 0 ;
200
+ }
201
+ }
202
+ else
203
+ {
204
+ int offset = sourceIOffset + j ;
205
+ float
206
+ max = px [ offset ] ,
207
+ next = px [ sourceIOffset + j + 1 ] ;
208
+ if ( next > max )
209
+ {
210
+ max = next ;
211
+ pdx [ offset ] = 0 ;
212
+ offset = sourceIOffset + j + 1 ;
213
+ }
214
+ else pdx [ sourceIOffset + j + 1 ] = 0 ;
215
+ next = px [ sourceI_1Offset + j ] ;
216
+ if ( next > max )
217
+ {
218
+ max = next ;
219
+ pdx [ offset ] = 0 ;
220
+ offset = sourceI_1Offset + j ;
221
+ }
222
+ else pdx [ sourceI_1Offset + j ] = 0 ;
223
+ next = px [ sourceI_1Offset + j + 1 ] ;
224
+ if ( next > max )
225
+ {
226
+ pdx [ offset ] = 0 ;
227
+ offset = sourceI_1Offset + j + 1 ;
228
+ }
229
+ else pdx [ sourceI_1Offset + j + 1 ] = 0 ;
230
+ pdx [ offset ] = pdy [ resultXOffset + r ++ ] ;
231
+ }
232
+ }
233
+ }
234
+ c ++ ;
235
+ }
236
+ }
237
+ }
238
+ Parallel . For ( 0 , n , Kernel ) . AssertCompleted ( ) ;
239
+ }
240
+ }
241
+ }
0 commit comments