Skip to content

Commit 5584d31

Browse files
committed
Pooling and convolution forward refactored
1 parent 5a65638 commit 5584d31

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

NeuralNetwork.NET/Networks/Layers/Cpu/ConvolutionalLayer.cs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using NeuralNetworkNET.APIs.Enums;
55
using NeuralNetworkNET.APIs.Interfaces;
66
using NeuralNetworkNET.APIs.Structs;
7+
using NeuralNetworkNET.cpuDNN;
78
using NeuralNetworkNET.Extensions;
89
using NeuralNetworkNET.Networks.Activations;
910
using NeuralNetworkNET.Networks.Activations.Delegates;
@@ -80,12 +81,15 @@ public ConvolutionalLayer(
8081
/// <inheritdoc/>
8182
public override unsafe void Forward(in Tensor x, out Tensor z, out Tensor a)
8283
{
83-
fixed (float* pw = Weights)
84+
fixed (float* pw = Weights, pb = Biases)
8485
{
85-
Tensor.Reshape(pw, OutputInfo.Channels, KernelInfo.Size, out Tensor wTensor);
86-
x.ConvoluteForward(InputInfo, wTensor, KernelInfo, Biases, out z);
87-
if (ActivationFunctionType == ActivationFunctionType.Identity) Tensor.From(z, z.Entities, z.Length, out a);
88-
else z.Activation(ActivationFunctions.Activation, out a);
86+
Tensor.Reshape(pw, OutputInfo.Channels, KernelInfo.Size, out Tensor w);
87+
Tensor.Reshape(pb, 1, Biases.Length, out Tensor b);
88+
Tensor.New(x.Entities, OutputInfo.Size, out z);
89+
CpuDnn.ConvolutionForward(x, InputInfo, w, KernelInfo, b, z);
90+
Tensor.New(z.Entities, z.Length, out a);
91+
if (ActivationFunctionType == ActivationFunctionType.Identity) a.Overwrite(z);
92+
else CpuDnn.ActivationForward(z, ActivationFunctions.Activation, a);
8993
}
9094
}
9195

@@ -107,10 +111,10 @@ public override unsafe void Backpropagate(in Tensor dy, in Tensor z, ActivationF
107111
public override void ComputeGradient(in Tensor a, in Tensor delta, out Tensor dJdw, out Tensor dJdb)
108112
{
109113
a.Rotate180(InputInfo.Channels, out Tensor a180);
110-
a180.ConvoluteGradient(InputInfo, delta, OutputInfo, out Tensor dJdwM);
114+
ConvolutionExtensions.ConvoluteGradient(a180, InputInfo, delta, OutputInfo, out Tensor dJdwM);
111115
dJdwM.Reshape(1, Weights.Length, out dJdw);
112116
a180.Free();
113-
delta.CompressVertically(OutputInfo.Channels, out dJdb);
117+
ConvolutionExtensions.CompressVertically(delta, OutputInfo.Channels, out dJdb);
114118
}
115119

116120
#endregion

NeuralNetwork.NET/Networks/Layers/Cpu/PoolingLayer.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using NeuralNetworkNET.APIs.Enums;
55
using NeuralNetworkNET.APIs.Interfaces;
66
using NeuralNetworkNET.APIs.Structs;
7+
using NeuralNetworkNET.cpuDNN;
78
using NeuralNetworkNET.Extensions;
89
using NeuralNetworkNET.Networks.Activations;
910
using NeuralNetworkNET.Networks.Activations.Delegates;
@@ -40,18 +41,20 @@ public PoolingLayer(in TensorInfo input, in PoolingInfo operation, ActivationFun
4041
/// <inheritdoc/>
4142
public override void Forward(in Tensor x, out Tensor z, out Tensor a)
4243
{
43-
x.Pool2x2(InputInfo.Channels, out z);
44-
z.Activation(ActivationFunctions.Activation, out a);
44+
Tensor.New(x.Entities, OutputInfo.Size, out z);
45+
CpuDnn.PoolingForward(x, InputInfo, z);
46+
Tensor.New(z.Entities, z.Length, out a);
47+
CpuDnn.ActivationForward(z, ActivationFunctions.Activation, a);
4548
}
4649

4750
/// <inheritdoc/>
48-
public override void Backpropagate(in Tensor dy, in Tensor z, ActivationFunction activationPrime) => z.UpscalePool2x2(dy, InputInfo.Channels);
51+
public override void Backpropagate(in Tensor dy, in Tensor z, ActivationFunction activationPrime) => CpuDnn.PoolingBackward(z, InputInfo, dy, z);
4952

5053
/// <inheritdoc/>
5154
public override INetworkLayer Clone() => new PoolingLayer(InputInfo, OperationInfo, ActivationFunctionType);
5255

5356
/// <inheritdoc/>
54-
public override void Serialize([NotNull] Stream stream)
57+
public override void Serialize(Stream stream)
5558
{
5659
base.Serialize(stream);
5760
stream.Write(OperationInfo);

0 commit comments

Comments
 (0)