Skip to content

Commit 0f42c29

Browse files
authored
Merge pull request #52 from Sergio0694/feature_cpuDNN-refactoring
Feature cpuDNN refactoring
2 parents 7af1c6f + 95309e0 commit 0f42c29

File tree

18 files changed

+1110
-1286
lines changed

18 files changed

+1110
-1286
lines changed

NeuralNetwork.NET.Cuda/AssemblyInfo.cs

Lines changed: 0 additions & 4 deletions
This file was deleted.

NeuralNetwork.NET.Cuda/NeuralNetwork.NET.Cuda.csproj

Lines changed: 0 additions & 21 deletions
This file was deleted.

NeuralNetwork.NET/APIs/Structs/Tensor.cs

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ namespace NeuralNetworkNET.APIs.Structs
1616
[DebuggerDisplay("Entities: {Entities}, Length: {Length}, Ptr: {Ptr}")]
1717
public readonly struct Tensor
1818
{
19+
#region Fields and parameters
20+
1921
/// <summary>
2022
/// The <see cref="IntPtr"/> value to the allocated memory
2123
/// </summary>
@@ -51,6 +53,8 @@ public bool IsNull
5153
get => Ptr == IntPtr.Zero;
5254
}
5355

56+
#endregion
57+
5458
/// <summary>
5559
/// Gets a null instance
5660
/// </summary>
@@ -109,6 +113,20 @@ public static unsafe void Reshape(float* p, int n, int chw, out Tensor tensor)
109113
tensor = new Tensor(ptr, n, chw);
110114
}
111115

116+
/// <summary>
117+
/// Creates a new instance with the same shape as the input <see cref="Tensor"/>
118+
/// </summary>
119+
/// <param name="mask">The <see cref="Tensor"/> to use to copy the shape</param>
120+
/// <param name="tensor">The output <see cref="Tensor"/></param>
121+
public static void Like(in Tensor mask, out Tensor tensor) => New(mask.Entities, mask.Length, out tensor);
122+
123+
/// <summary>
124+
/// Creates a new instance with the same shape as the input <see cref="Tensor"/> and all the values initializes to 0
125+
/// </summary>
126+
/// <param name="mask">The <see cref="Tensor"/> to use to copy the shape</param>
127+
/// <param name="tensor">The output <see cref="Tensor"/></param>
128+
public static void LikeZeroed(in Tensor mask, out Tensor tensor) => NewZeroed(mask.Entities, mask.Length, out tensor);
129+
112130
/// <summary>
113131
/// Creates a new instance by copying the contents at the given memory location and reshaping it to the desired size
114132
/// </summary>
@@ -155,6 +173,34 @@ public static unsafe void From([NotNull] float[] v, int n, int chw, out Tensor t
155173

156174
#region Tools
157175

176+
/// <summary>
177+
/// Creates a new instance by wrapping the current memory area
178+
/// </summary>
179+
/// <param name="n">The height of the final matrix</param>
180+
/// <param name="chw">The width of the final matrix</param>
181+
/// <param name="tensor">The resulting instance</param>
182+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
183+
public void Reshape(int n, int chw, out Tensor tensor)
184+
{
185+
if (n * chw != Size) throw new ArgumentException("Invalid input resized shape");
186+
tensor = new Tensor(Ptr, n, chw);
187+
}
188+
189+
/// <summary>
190+
/// Checks whether or not the current instance has the same shape of the input <see cref="Tensor"/>
191+
/// </summary>
192+
/// <param name="tensor">The instance to compare</param>
193+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
194+
public bool MatchShape(in Tensor tensor) => Entities == tensor.Entities && Length == tensor.Length;
195+
196+
/// <summary>
197+
/// Checks whether or not the current instance has the same shape as the input arguments
198+
/// </summary>
199+
/// <param name="entities">The expected number of entities</param>
200+
/// <param name="length">The expected length of each entity</param>
201+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
202+
public bool MatchShape(int entities, int length) => Entities == entities && Length == length;
203+
158204
/// <summary>
159205
/// Overwrites the contents of the current matrix with the input matrix
160206
/// </summary>
@@ -181,43 +227,36 @@ internal void Duplicate(out Tensor tensor)
181227
/// <summary>
182228
/// Copies the contents of the unmanaged array to a managed <see cref="Array"/>
183229
/// </summary>
230+
/// <param name="keepAlive">Indicates whether or not to automatically dispose the current instance</param>
184231
[Pure, NotNull]
185-
public float[] ToArray()
232+
public float[] ToArray(bool keepAlive = true)
186233
{
187234
if (Ptr == IntPtr.Zero) return new float[0];
188235
float[] result = new float[Size];
189236
Marshal.Copy(Ptr, result, 0, Size);
237+
if (!keepAlive) Free();
190238
return result;
191239
}
192240

193241
/// <summary>
194242
/// Copies the contents of the unmanaged array to a managed 2D <see cref="Array"/>
195243
/// </summary>
244+
/// <param name="keepAlive">Indicates whether or not to automatically dispose the current instance</param>
196245
[Pure, NotNull]
197-
public unsafe float[,] ToArray2D()
246+
public unsafe float[,] ToArray2D(bool keepAlive = true)
198247
{
199248
if (Ptr == IntPtr.Zero) return new float[0, 0];
200249
float[,] result = new float[Entities, Length];
201250
int size = sizeof(float) * Size;
202251
fixed (float* presult = result)
203252
Buffer.MemoryCopy(this, presult, size, size);
253+
if (!keepAlive) Free();
204254
return result;
205255
}
206256

207257
#endregion
208258

209-
/// <summary>
210-
/// Creates a new instance by wrapping the current memory area
211-
/// </summary>
212-
/// <param name="n">The height of the final matrix</param>
213-
/// <param name="chw">The width of the final matrix</param>
214-
/// <param name="tensor">The resulting instance</param>
215-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
216-
public void Reshape(int n, int chw, out Tensor tensor)
217-
{
218-
if (n * chw != Size) throw new ArgumentException("Invalid input resized shape");
219-
tensor = new Tensor(Ptr, n, chw);
220-
}
259+
#region Memory management
221260

222261
/// <summary>
223262
/// Frees the memory associated with the current instance
@@ -237,8 +276,11 @@ public void TryFree()
237276

238277
// Implicit pointer conversion
239278
[MethodImpl(MethodImplOptions.AggressiveInlining)]
279+
[SuppressMessage("ReSharper", "ImpureMethodCallOnReadonlyValueField")]
240280
public static unsafe implicit operator float*(in Tensor tensor) => (float*)tensor.Ptr.ToPointer();
241281

282+
#endregion
283+
242284
#region Debug
243285

244286
/// <summary>

0 commit comments

Comments
 (0)