Skip to content

Commit b3d41e3

Browse files
committed
Pooling code refactored (WIP)
1 parent 05f31b1 commit b3d41e3

File tree

1 file changed

+241
-0
lines changed

1 file changed

+241
-0
lines changed
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
using NeuralNetworkNET.APIs.Structs;
2+
using System;
3+
using System.Threading.Tasks;
4+
using NeuralNetworkNET.Extensions;
5+
6+
namespace NeuralNetworkNET.cpuDNN
7+
{
8+
/// <summary>
9+
/// A static class with a collection of pooling extension methods
10+
/// </summary>
11+
public static partial class CpuDnn
12+
{
13+
/// <summary>
14+
/// Executes the forward pass on a max pooling layer with a 2x2 window and a stride of 2
15+
/// </summary>
16+
/// <param name="x">The input <see cref="Tensor"/> to pool</param>
17+
/// <param name="xInfo">The info on the input <see cref="Tensor"/></param>
18+
/// <param name="y">The resulting pooled <see cref="Tensor"/></param>
19+
public static unsafe void PoolingForward(in Tensor x, in TensorInfo xInfo, in Tensor y)
20+
{
21+
int h = x.Entities, w = x.Length;
22+
if (h < 1 || w < 1) throw new ArgumentException("The input matrix isn't valid");
23+
int
24+
depth = xInfo.Channels,
25+
imgSize = w % depth == 0 ? w / depth : throw new ArgumentException(nameof(x), "Invalid depth parameter for the input matrix"),
26+
imgAxis = imgSize.IntegerSquare(); // Size of an edge of one of the inner images per sample
27+
if (imgAxis * imgAxis != imgSize) throw new ArgumentOutOfRangeException(nameof(x), "The size of the input matrix isn't valid");
28+
int
29+
poolAxis = imgAxis / 2 + (imgAxis % 2 == 0 ? 0 : 1),
30+
poolSize = poolAxis * poolAxis,
31+
poolFinalWidth = depth * poolSize,
32+
edge = imgAxis - 1;
33+
if (!y.MatchShape(h, poolFinalWidth)) throw new ArgumentException("The output tensor shape isn't valid", nameof(y));
34+
35+
// Pooling kernel
36+
float* px = x, py = y;
37+
void Kernel(int sample)
38+
{
39+
int
40+
sourceBaseOffset = sample * w,
41+
resultBaseOffset = sample * poolFinalWidth;
42+
for (int z = 0; z < depth; z++)
43+
{
44+
int
45+
sourceZOffset = sourceBaseOffset + z * imgSize,
46+
resultZOffset = resultBaseOffset + z * poolSize,
47+
c = 0;
48+
for (int i = 0; i < imgAxis; i += 2)
49+
{
50+
int
51+
sourceIOffset = sourceZOffset + i * imgAxis,
52+
resultXOffset = resultZOffset + c * poolAxis,
53+
r = 0;
54+
if (i == edge)
55+
{
56+
// Last row
57+
for (int j = 0; j < imgAxis; j += 2)
58+
{
59+
float max;
60+
if (j == w - 1) max = px[sourceIOffset + j]; // Last column
61+
else
62+
{
63+
float
64+
left = px[sourceIOffset + j],
65+
right = px[sourceIOffset + j + 1];
66+
max = left > right ? left : right;
67+
}
68+
py[resultXOffset + r++] = max;
69+
}
70+
}
71+
else
72+
{
73+
int sourceI_1Offset = sourceZOffset + (i + 1) * imgAxis;
74+
for (int j = 0; j < imgAxis; j += 2)
75+
{
76+
float max;
77+
if (j == edge)
78+
{
79+
// Last column
80+
float
81+
up = px[sourceIOffset + j],
82+
down = px[sourceI_1Offset + j];
83+
max = up > down ? up : down;
84+
}
85+
else
86+
{
87+
float
88+
upLeft = px[sourceIOffset + j],
89+
upRight = px[sourceIOffset + j + 1],
90+
downLeft = px[sourceI_1Offset + j],
91+
downRight = px[sourceI_1Offset + j + 1],
92+
maxUp = upLeft > upRight ? upLeft : upRight,
93+
maxDown = downLeft > downRight ? downLeft : downRight;
94+
max = maxUp > maxDown ? maxUp : maxDown;
95+
}
96+
py[resultXOffset + r++] = max;
97+
}
98+
}
99+
c++;
100+
}
101+
}
102+
}
103+
Parallel.For(0, h, Kernel).AssertCompleted();
104+
}
105+
106+
/// <summary>
107+
/// Executes the backward pass on a max pooling layer with a 2x2 window and a stride of 2
108+
/// </summary>
109+
/// <param name="x">The original input <see cref="Tensor"/> used during the forward pass</param>
110+
/// <param name="xInfo">The info on the input <see cref="Tensor"/></param>
111+
/// <param name="dy">The output error for the current layer</param>
112+
/// <param name="dx">The resulting backpropagated error</param>
113+
public static unsafe void PoolingBackward(in Tensor x, in TensorInfo xInfo, in Tensor dy, in Tensor dx)
114+
{
115+
// Prepare the result matrix
116+
if (!dx.MatchShape(x)) throw new ArgumentException("The result tensor must have the same shape as the input", nameof(dx));
117+
int n = x.Entities, l = x.Length;
118+
if (n < 1 || l < 1) throw new ArgumentException("The input matrix isn't valid");
119+
int
120+
depth = xInfo.Channels,
121+
imgSize = l % depth == 0 ? l / depth : throw new ArgumentException(nameof(x), "Invalid depth parameter for the input matrix"),
122+
imgAxis = imgSize.IntegerSquare(); // Size of an edge of one of the inner images per sample
123+
if (imgAxis * imgAxis != imgSize) throw new ArgumentOutOfRangeException(nameof(x), "The size of the input matrix isn't valid");
124+
int
125+
poolAxis = imgAxis / 2 + (imgAxis % 2 == 0 ? 0 : 1),
126+
poolSize = poolAxis * poolAxis,
127+
poolFinalWidth = depth * poolSize,
128+
edge = imgAxis - 1;
129+
int
130+
pn = dy.Entities,
131+
pl = dy.Length;
132+
if (pn != n || pl != poolFinalWidth) throw new ArgumentException("Invalid pooled matrix", nameof(dy));
133+
134+
// Pooling kernel
135+
float* px = x, pdy = dy, pdx = dx;
136+
void Kernel(int sample)
137+
{
138+
int
139+
sourceBaseOffset = sample * l,
140+
resultBaseOffset = sample * poolFinalWidth;
141+
for (int z = 0; z < depth; z++)
142+
{
143+
int
144+
sourceZOffset = sourceBaseOffset + z * imgSize,
145+
resultZOffset = resultBaseOffset + z * poolSize,
146+
c = 0;
147+
for (int i = 0; i < imgAxis; i += 2)
148+
{
149+
int
150+
sourceIOffset = sourceZOffset + i * imgAxis,
151+
resultXOffset = resultZOffset + c * poolAxis,
152+
r = 0;
153+
if (i == edge)
154+
{
155+
// Last row
156+
for (int j = 0; j < imgAxis; j += 2)
157+
{
158+
if (j == l - 1)
159+
{
160+
pdx[sourceIOffset + j] = pdy[resultXOffset + r++];
161+
}
162+
else
163+
{
164+
float
165+
left = px[sourceIOffset + j],
166+
right = px[sourceIOffset + j + 1];
167+
if (left > right)
168+
{
169+
pdx[sourceIOffset + j] = pdy[resultXOffset + r++];
170+
pdx[sourceIOffset + j + 1] = 0;
171+
}
172+
else
173+
{
174+
pdx[sourceIOffset + j + 1] = pdy[resultXOffset + r++];
175+
pdx[sourceIOffset + j] = 0;
176+
}
177+
}
178+
}
179+
}
180+
else
181+
{
182+
int sourceI_1Offset = sourceZOffset + (i + 1) * imgAxis;
183+
for (int j = 0; j < imgAxis; j += 2)
184+
{
185+
if (j == edge)
186+
{
187+
// Last column
188+
float
189+
up = px[sourceIOffset + j],
190+
down = px[sourceI_1Offset + j];
191+
if (up > down)
192+
{
193+
pdx[sourceIOffset + j] = pdy[resultXOffset + r++];
194+
pdx[sourceI_1Offset + j] = 0;
195+
}
196+
else
197+
{
198+
pdx[sourceI_1Offset + j] = pdy[resultXOffset + r++];
199+
pdx[sourceIOffset + j] = 0;
200+
}
201+
}
202+
else
203+
{
204+
int offset = sourceIOffset + j;
205+
float
206+
max = px[offset],
207+
next = px[sourceIOffset + j + 1];
208+
if (next > max)
209+
{
210+
max = next;
211+
pdx[offset] = 0;
212+
offset = sourceIOffset + j + 1;
213+
}
214+
else pdx[sourceIOffset + j + 1] = 0;
215+
next = px[sourceI_1Offset + j];
216+
if (next > max)
217+
{
218+
max = next;
219+
pdx[offset] = 0;
220+
offset = sourceI_1Offset + j;
221+
}
222+
else pdx[sourceI_1Offset + j] = 0;
223+
next = px[sourceI_1Offset + j + 1];
224+
if (next > max)
225+
{
226+
pdx[offset] = 0;
227+
offset = sourceI_1Offset + j + 1;
228+
}
229+
else pdx[sourceI_1Offset + j + 1] = 0;
230+
pdx[offset] = pdy[resultXOffset + r++];
231+
}
232+
}
233+
}
234+
c++;
235+
}
236+
}
237+
}
238+
Parallel.For(0, n, Kernel).AssertCompleted();
239+
}
240+
}
241+
}

0 commit comments

Comments
 (0)