4
4
using NeuralNetworkNET . APIs . Enums ;
5
5
using NeuralNetworkNET . APIs . Interfaces ;
6
6
using NeuralNetworkNET . APIs . Structs ;
7
+ using NeuralNetworkNET . cpuDNN ;
7
8
using NeuralNetworkNET . Extensions ;
8
9
using NeuralNetworkNET . Networks . Activations ;
9
10
using NeuralNetworkNET . Networks . Activations . Delegates ;
@@ -80,12 +81,15 @@ public ConvolutionalLayer(
80
81
/// <inheritdoc/>
81
82
public override unsafe void Forward ( in Tensor x , out Tensor z , out Tensor a )
82
83
{
83
- fixed ( float * pw = Weights )
84
+ fixed ( float * pw = Weights , pb = Biases )
84
85
{
85
- Tensor . Reshape ( pw , OutputInfo . Channels , KernelInfo . Size , out Tensor wTensor ) ;
86
- x . ConvoluteForward ( InputInfo , wTensor , KernelInfo , Biases , out z ) ;
87
- if ( ActivationFunctionType == ActivationFunctionType . Identity ) Tensor . From ( z , z . Entities , z . Length , out a ) ;
88
- else z . Activation ( ActivationFunctions . Activation , out a ) ;
86
+ Tensor . Reshape ( pw , OutputInfo . Channels , KernelInfo . Size , out Tensor w ) ;
87
+ Tensor . Reshape ( pb , 1 , Biases . Length , out Tensor b ) ;
88
+ Tensor . New ( x . Entities , OutputInfo . Size , out z ) ;
89
+ CpuDnn . ConvolutionForward ( x , InputInfo , w , KernelInfo , b , z ) ;
90
+ Tensor . New ( z . Entities , z . Length , out a ) ;
91
+ if ( ActivationFunctionType == ActivationFunctionType . Identity ) a . Overwrite ( z ) ;
92
+ else CpuDnn . ActivationForward ( z , ActivationFunctions . Activation , a ) ;
89
93
}
90
94
}
91
95
@@ -107,10 +111,10 @@ public override unsafe void Backpropagate(in Tensor dy, in Tensor z, ActivationF
107
111
public override void ComputeGradient ( in Tensor a , in Tensor delta , out Tensor dJdw , out Tensor dJdb )
108
112
{
109
113
a . Rotate180 ( InputInfo . Channels , out Tensor a180 ) ;
110
- a180 . ConvoluteGradient ( InputInfo , delta , OutputInfo , out Tensor dJdwM ) ;
114
+ ConvolutionExtensions . ConvoluteGradient ( a180 , InputInfo , delta , OutputInfo , out Tensor dJdwM ) ;
111
115
dJdwM . Reshape ( 1 , Weights . Length , out dJdw ) ;
112
116
a180 . Free ( ) ;
113
- delta . CompressVertically ( OutputInfo . Channels , out dJdb ) ;
117
+ ConvolutionExtensions . CompressVertically ( delta , OutputInfo . Channels , out dJdb ) ;
114
118
}
115
119
116
120
#endregion
0 commit comments