@@ -139,7 +139,7 @@ public static unsafe void ConvolutionBackwardData(
             if (imgSize < kSize) throw new ArgumentException("Each subdivided tensor must at least have the size of the kernels");
             if (dyInfo.Channels != nKernels) throw new ArgumentException("The source depth must be equal to the number of kernels");

-            // Traanspose the layer kernels
+            // Rotate the layer kernels
             Rotate180(w, wInfo.Channels, out Tensor w180);

             /* ============================
@@ -212,11 +212,12 @@ void BackwardsKernel(int index)
         /// <param name="dy">The output error <see cref="Tensor"/></param>
         /// <param name="dyInfo">The output error volume info (depth and 2D slices size)</param>
         /// <param name="dw">The resulting weights gradient</param>
+        /// <param name="wInfo">The info on the layer kernels</param>
         /// <exception cref="ArgumentException">The size of one of the input <see cref="Tensor"/> instances isn't valid</exception>
         public static unsafe void ConvolutionBackwardFilter(
             in Tensor x, in TensorInfo xInfo,
             in Tensor dy, in TensorInfo dyInfo,
-            in Tensor dw)
+            in Tensor dw, in TensorInfo wInfo)
         {
             // Checks and local parameters
             int
@@ -244,15 +245,19 @@ public static unsafe void ConvolutionBackwardFilter(
              * Kernels: HK*WK*sourceDepth*kernelsDepth (delta(l + 1) used to calculate the 3D gradient for each kernel)
              * Output: sourceDepth*kernelsDepth slices, where each stack of sourceDepth slices is the gradient for the i-th kernel */
             int
-                hResult = imgHeight - kHeight + 1,                      // Size of each image edge after the convolution
+                hResult = imgHeight - kHeight + 1,                          // Size of each image edge after the convolution
                 wResult = imgWidth - kWidth + 1,
-                convolutionOutputSize = hResult * wResult,              // Size of each processed image
-                gradientSize = convolutionOutputSize * xInfo.Channels,  // Size of each calculated gradient (one for each original kernel, so for each input delta)
-                finalWidth = gradientSize * dyInfo.Channels,            // Final size of each sample row
-                iterationsPerSample = xInfo.Channels * kDepth;          // Each sample has its own list of 3D gradients, one for each kernel
+                convolutionOutputSize = hResult * wResult,                  // Size of each processed image
+                gradientSize = convolutionOutputSize * xInfo.Channels,      // Size of each calculated gradient (one for each original kernel, so for each input delta)
+                finalWidth = gradientSize * dyInfo.Channels,                // Final size of each sample row
+                iterationsPerSample = xInfo.Channels * kDepth;              // Each sample has its own list of 3D gradients, one for each kernel
+
+            // Rotate the inputs and prepare the temporary tensor
+            Rotate180(x, xInfo.Channels, out Tensor xt);
+            Tensor.New(x.Entities, finalWidth, out Tensor dwTemp);

             // Process the valid convolution
-            float* px = x, pdy = dy, pdw = dw;
+            float* px = xt, pdy = dy, pdw = dwTemp;
             void GradientKernel(int index)
             {
                 // Calculate the current indexes
@@ -291,7 +296,17 @@ void GradientKernel(int index)
                 }
             }
             Parallel.For(0, n * iterationsPerSample, GradientKernel).AssertCompleted();
-            throw new NotImplementedException("The CPU gradient convolution isn't implemented correctly yet");
+            xt.Free();
+
+            /* ==========================
+             * Gradient compression
+             * ==========================
+             * At this point, the temporary tensor has the series of (p,q) gradients for all the layer
+             * kernels, where p is the input depth and q is the kernel index.
+             * The final weights gradient is the sum for all the samples in the current training batch */
+            dw.Reshape(1, dw.Size, out Tensor wPlane); // The gradient is [q,p]-shaped, flatten to the size of each sample before compressing
+            CpuBlas.CompressVertically(dwTemp, wPlane);
+            dwTemp.Free();
         }

         /// <summary>
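For reference, here is a minimal sketch of the semantics of the two helpers this commit leans on, Rotate180 and CpuBlas.CompressVertically. The plain-array signatures below are illustrative assumptions rather than the library's Tensor-based API shown in the diff: a 180-degree rotation of a 2D slice is just a reversal of its flattened values, and the vertical compression sums the per-sample gradient rows into the single row that gets written to the flattened dw tensor.

using System;

static class ConvolutionGradientSketch
{
    // Hypothetical stand-in for Rotate180: rotating each 2D slice by 180 degrees
    // is equivalent to reversing the flattened values of that slice, applied
    // independently to every slice stored in the tensor.
    public static float[] Rotate180(float[] source, int sliceHeight, int sliceWidth)
    {
        int sliceSize = sliceHeight * sliceWidth;
        if (sliceSize == 0 || source.Length % sliceSize != 0)
            throw new ArgumentException("The input length must be a multiple of the slice size");
        float[] result = new float[source.Length];
        for (int offset = 0; offset < source.Length; offset += sliceSize)
            for (int i = 0; i < sliceSize; i++)
                result[offset + i] = source[offset + sliceSize - 1 - i];
        return result;
    }

    // Hypothetical stand-in for CpuBlas.CompressVertically: sums a [samples, width]
    // matrix over its rows, which is how the per-sample gradients in dwTemp are
    // reduced to the batch-wide weights gradient.
    public static float[] CompressVertically(float[,] input)
    {
        int rows = input.GetLength(0), columns = input.GetLength(1);
        float[] result = new float[columns];
        for (int j = 0; j < columns; j++)
        {
            float sum = 0f;
            for (int i = 0; i < rows; i++) sum += input[i, j];
            result[j] = sum;
        }
        return result;
    }
}

The real helpers operate on the library's Tensor type, as the diff shows, so this sketch only captures the math rather than the memory layout or the unsafe pointer handling used in the actual code.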