7
7
8
8
namespace NeuralNetworkNET . cpuDNN
9
9
{
10
+ /// <summary>
11
+ /// A static class that contains primitives to implement a CNN running on CPU
12
+ /// </summary>
10
13
public static partial class CpuDnn
11
14
{
12
15
#region Activation
13
16
17
+ /// <summary>
18
+ /// Executes the input activation function on the target <see cref="Tensor"/>
19
+ /// </summary>
20
+ /// <param name="x">The layer input <see cref="Tensor"/></param>
21
+ /// <param name="f">The activation function to apply to the input</param>
22
+ /// <param name="y">The output <see cref="Tensor"/> - it can be the same as the input</param>
14
23
public static unsafe void ActivationForward ( in Tensor x , [ NotNull ] ActivationFunction f , in Tensor y )
15
24
{
16
25
// Setup
@@ -31,15 +40,22 @@ void Kernel(int i)
31
40
Parallel . For ( 0 , n , Kernel ) . AssertCompleted ( ) ;
32
41
}
33
42
34
- public static unsafe void ActivationBackward ( in Tensor x , in Tensor y , [ NotNull ] ActivationFunction f_ , in Tensor dx )
43
+ /// <summary>
44
+ /// Executes the backward activation function on the target <see cref="Tensor"/>, with the given error delta
45
+ /// </summary>
46
+ /// <param name="x">The activity on the input layer</param>
47
+ /// <param name="dy">The current error delta to backpropagate</param>
48
+ /// <param name="f_">The derivative of the activation function used in the forward pass</param>
49
+ /// <param name="dx">The resulting input error delta - it can be the same as the input <see cref="Tensor"/></param>
50
+ public static unsafe void ActivationBackward ( in Tensor x , in Tensor dy , [ NotNull ] ActivationFunction f_ , in Tensor dx )
35
51
{
36
52
// Check
37
- if ( ! y . MatchShape ( x ) ) throw new ArgumentException ( "The input tensors must have the same shape" , nameof ( y ) ) ;
38
- if ( ! dx . MatchShape ( x ) ) throw new ArgumentException ( "The output tensor must have the same shape as the input" , nameof ( y ) ) ;
53
+ if ( ! dy . MatchShape ( x ) ) throw new ArgumentException ( "The input tensors must have the same shape" , nameof ( dy ) ) ;
54
+ if ( ! dx . MatchShape ( x ) ) throw new ArgumentException ( "The output tensor must have the same shape as the input" , nameof ( dy ) ) ;
39
55
int
40
56
n = x . Entities ,
41
57
l = x . Length ;
42
- float * px = x , py = y , pdx = dx ;
58
+ float * px = x , pdy = dy , pdx = dx ;
43
59
44
60
// Loop in parallel
45
61
void Kernel ( int i )
@@ -48,7 +64,7 @@ void Kernel(int i)
48
64
for ( int j = 0 ; j < l ; j ++ )
49
65
{
50
66
int target = offset + j ;
51
- pdx [ target ] = f_ ( px [ target ] ) * py [ target ] ;
67
+ pdx [ target ] = f_ ( px [ target ] ) * pdy [ target ] ;
52
68
}
53
69
}
54
70
Parallel . For ( 0 , n , Kernel ) . AssertCompleted ( ) ;
@@ -58,6 +74,13 @@ void Kernel(int i)
58
74
59
75
#region Fully connected
60
76
77
+ /// <summary>
78
+ /// Executes the forward pass on a fully connected layer
79
+ /// </summary>
80
+ /// <param name="x">The input <see cref="Tensor"/> to process</param>
81
+ /// <param name="w">The layer weights</param>
82
+ /// <param name="b">The layer biases</param>
83
+ /// <param name="y">The output <see cref="Tensor"/> for the current layer</param>
61
84
public static unsafe void FullyConnectedForward ( in Tensor x , in Tensor w , in Tensor b , in Tensor y )
62
85
{
63
86
// Initialize the parameters and the result matrix
@@ -89,6 +112,14 @@ void Kernel(int i)
89
112
Parallel . For ( 0 , h , Kernel ) . AssertCompleted ( ) ;
90
113
}
91
114
115
+ /// <summary>
116
+ /// Executes the backward pass on a fully connected layer
117
+ /// </summary>
118
+ /// <param name="x">The activity on the layer inputs</param>
119
+ /// <param name="w">The layer weights</param>
120
+ /// <param name="dy">The output error delta</param>
121
+ /// <param name="f_">The derivative of the activation function used in the forward pass</param>
122
+ /// <param name="dx">The resulting input error delta</param>
92
123
public static unsafe void FullyConnectedBackwardData ( in Tensor x , in Tensor w , in Tensor dy , [ NotNull ] ActivationFunction f_ , in Tensor dx )
93
124
{
94
125
if ( w . Entities != x . Length ) throw new ArgumentException ( "The weights tensor doesn't have a valid shape" , nameof ( w ) ) ;
@@ -127,6 +158,12 @@ void Kernel(int i)
127
158
wt . Free ( ) ;
128
159
}
129
160
161
+ /// <summary>
162
+ /// Executes the backward pass on a fully connected layer to calculate the gradient with respect to the weights
163
+ /// </summary>
164
+ /// <param name="x">The layer inputs</param>
165
+ /// <param name="dy">The layer output error delta</param>
166
+ /// <param name="dw">The resulting weights gradient <see cref="Tensor"/></param>
130
167
public static void FullyConnectedBackwardFilter ( in Tensor x , in Tensor dy , in Tensor dw )
131
168
{
132
169
if ( x . Entities != dy . Entities ) throw new ArgumentException ( "The input tensor doesn't match the number of samples from the delta" , nameof ( x ) ) ;
@@ -136,10 +173,15 @@ public static void FullyConnectedBackwardFilter(in Tensor x, in Tensor dy, in Te
136
173
xt . Free ( ) ;
137
174
}
138
175
176
+ /// <summary>
177
+ /// Executes the backward pass on a fully connected layer to calculate the gradient with respect to the biases
178
+ /// </summary>
179
+ /// <param name="dy">The layer output error delta</param>
180
+ /// <param name="db">The resulting biases gradient <see cref="Tensor"/></param>
139
181
public static unsafe void FullyConnectedBackwardBias ( in Tensor dy , in Tensor db )
140
182
{
141
183
// Preliminary checks and declarations
142
- if ( ! db . MatchShape ( 1 , dy . Length ) ) throw new ArgumentException ( "The output tensor doesn't have the right shape" , nameof ( db ) ) ;
184
+ if ( ! db . MatchShape ( 1 , dy . Length ) ) throw new ArgumentException ( "Invalid result tensor shape" , nameof ( db ) ) ;
143
185
int
144
186
n = dy . Entities ,
145
187
l = dy . Length ;
@@ -150,8 +192,8 @@ void Kernel(int j)
150
192
{
151
193
float sum = 0 ;
152
194
for ( int i = 0 ; i < n ; i ++ )
153
- sum += pdb [ i * l + j ] ;
154
- pdy [ j ] = sum ;
195
+ sum += pdy [ i * l + j ] ;
196
+ pdb [ j ] = sum ;
155
197
}
156
198
Parallel . For ( 0 , l , Kernel ) . AssertCompleted ( ) ;
157
199
}
0 commit comments