Skip to content

Commit 42ebecf

Browse files
committed
More APIs refactoring
1 parent 6a7f6e8 commit 42ebecf

File tree

3 files changed

+95
-171
lines changed

3 files changed

+95
-171
lines changed

NeuralNetwork.NET/APIs/Structs/Tensor.cs

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ namespace NeuralNetworkNET.APIs.Structs
1616
[DebuggerDisplay("Entities: {Entities}, Length: {Length}, Ptr: {Ptr}")]
1717
public readonly struct Tensor
1818
{
19+
#region Fields and parameters
20+
1921
/// <summary>
2022
/// The <see cref="IntPtr"/> value to the allocated memory
2123
/// </summary>
@@ -51,6 +53,8 @@ public bool IsNull
5153
get => Ptr == IntPtr.Zero;
5254
}
5355

56+
#endregion
57+
5458
/// <summary>
5559
/// Gets a null instance
5660
/// </summary>
@@ -109,6 +113,20 @@ public static unsafe void Reshape(float* p, int n, int chw, out Tensor tensor)
109113
tensor = new Tensor(ptr, n, chw);
110114
}
111115

116+
/// <summary>
117+
/// Creates a new instance with the same shape as the input <see cref="Tensor"/>
118+
/// </summary>
119+
/// <param name="mask">The <see cref="Tensor"/> to use to copy the shape</param>
120+
/// <param name="tensor">The output <see cref="Tensor"/></param>
121+
public static void Like(in Tensor mask, out Tensor tensor) => New(mask.Entities, mask.Length, out tensor);
122+
123+
/// <summary>
124+
/// Creates a new instance with the same shape as the input <see cref="Tensor"/> and all the values initializes to 0
125+
/// </summary>
126+
/// <param name="mask">The <see cref="Tensor"/> to use to copy the shape</param>
127+
/// <param name="tensor">The output <see cref="Tensor"/></param>
128+
public static void LikeZeroed(in Tensor mask, out Tensor tensor) => NewZeroed(mask.Entities, mask.Length, out tensor);
129+
112130
/// <summary>
113131
/// Creates a new instance by copying the contents at the given memory location and reshaping it to the desired size
114132
/// </summary>
@@ -155,6 +173,34 @@ public static unsafe void From([NotNull] float[] v, int n, int chw, out Tensor t
155173

156174
#region Tools
157175

176+
/// <summary>
177+
/// Creates a new instance by wrapping the current memory area
178+
/// </summary>
179+
/// <param name="n">The height of the final matrix</param>
180+
/// <param name="chw">The width of the final matrix</param>
181+
/// <param name="tensor">The resulting instance</param>
182+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
183+
public void Reshape(int n, int chw, out Tensor tensor)
184+
{
185+
if (n * chw != Size) throw new ArgumentException("Invalid input resized shape");
186+
tensor = new Tensor(Ptr, n, chw);
187+
}
188+
189+
/// <summary>
190+
/// Checks whether or not the current instance has the same shape of the input <see cref="Tensor"/>
191+
/// </summary>
192+
/// <param name="tensor">The instance to compare</param>
193+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
194+
public bool MatchShape(in Tensor tensor) => Entities == tensor.Entities && Length == tensor.Length;
195+
196+
/// <summary>
197+
/// Checks whether or not the current instance has the same shape as the input arguments
198+
/// </summary>
199+
/// <param name="entities">The expected number of entities</param>
200+
/// <param name="length">The expected length of each entity</param>
201+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
202+
public bool MatchShape(int entities, int length) => Entities == entities && Length == length;
203+
158204
/// <summary>
159205
/// Overwrites the contents of the current matrix with the input matrix
160206
/// </summary>
@@ -181,58 +227,36 @@ internal void Duplicate(out Tensor tensor)
181227
/// <summary>
182228
/// Copies the contents of the unmanaged array to a managed <see cref="Array"/>
183229
/// </summary>
230+
/// <param name="keepAlive">Indicates whether or not to automatically dispose the current instance</param>
184231
[Pure, NotNull]
185-
public float[] ToArray()
232+
public float[] ToArray(bool keepAlive = true)
186233
{
187234
if (Ptr == IntPtr.Zero) return new float[0];
188235
float[] result = new float[Size];
189236
Marshal.Copy(Ptr, result, 0, Size);
237+
if (keepAlive) Free();
190238
return result;
191239
}
192240

193241
/// <summary>
194242
/// Copies the contents of the unmanaged array to a managed 2D <see cref="Array"/>
195243
/// </summary>
244+
/// <param name="keepAlive">Indicates whether or not to automatically dispose the current instance</param>
196245
[Pure, NotNull]
197-
public unsafe float[,] ToArray2D()
246+
public unsafe float[,] ToArray2D(bool keepAlive = true)
198247
{
199248
if (Ptr == IntPtr.Zero) return new float[0, 0];
200249
float[,] result = new float[Entities, Length];
201250
int size = sizeof(float) * Size;
202251
fixed (float* presult = result)
203252
Buffer.MemoryCopy(this, presult, size, size);
253+
if (keepAlive) Free();
204254
return result;
205255
}
206256

207257
#endregion
208258

209-
/// <summary>
210-
/// Creates a new instance by wrapping the current memory area
211-
/// </summary>
212-
/// <param name="n">The height of the final matrix</param>
213-
/// <param name="chw">The width of the final matrix</param>
214-
/// <param name="tensor">The resulting instance</param>
215-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
216-
public void Reshape(int n, int chw, out Tensor tensor)
217-
{
218-
if (n * chw != Size) throw new ArgumentException("Invalid input resized shape");
219-
tensor = new Tensor(Ptr, n, chw);
220-
}
221-
222-
/// <summary>
223-
/// Checks whether or not the current instance has the same shape of the input <see cref="Tensor"/>
224-
/// </summary>
225-
/// <param name="tensor">The instance to compare</param>
226-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
227-
public bool MatchShape(in Tensor tensor) => Entities == tensor.Entities && Length == tensor.Length;
228-
229-
/// <summary>
230-
/// Checks whether or not the current instance has the same shape as the input arguments
231-
/// </summary>
232-
/// <param name="entities">The expected number of entities</param>
233-
/// <param name="length">The expected length of each entity</param>
234-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
235-
public bool MatchShape(int entities, int length) => Entities == entities && Length == length;
259+
#region Memory management
236260

237261
/// <summary>
238262
/// Frees the memory associated with the current instance
@@ -252,8 +276,11 @@ public void TryFree()
252276

253277
// Implicit pointer conversion
254278
[MethodImpl(MethodImplOptions.AggressiveInlining)]
279+
[SuppressMessage("ReSharper", "ImpureMethodCallOnReadonlyValueField")]
255280
public static unsafe implicit operator float*(in Tensor tensor) => (float*)tensor.Ptr.ToPointer();
256281

282+
#endregion
283+
257284
#region Debug
258285

259286
/// <summary>

NeuralNetwork.NET/Extensions/MatrixExtensions.cs

Lines changed: 14 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using System;
22
using System.Collections.Generic;
3-
using System.Linq;
43
using System.Runtime.InteropServices;
54
using System.Text;
65
using System.Threading.Tasks;
@@ -53,33 +52,6 @@ void NormalizationKernel(int i)
5352

5453
#region Misc
5554

56-
/// <summary>
57-
/// Calculates the position and the value of the biggest item in a matrix
58-
/// </summary>
59-
/// <param name="m">The input matrix</param>
60-
internal static (int x, int y, float value) Max([NotNull] this float[,] m)
61-
{
62-
// Checks and local variables setup
63-
if (m.Length == 0) throw new ArgumentOutOfRangeException("The input matrix can't be empty");
64-
if (m.Length == 1) return (0, 0, m[0, 0]);
65-
int
66-
h = m.GetLength(0),
67-
w = m.GetLength(1),
68-
x = 0, y = 0;
69-
float max = float.MinValue;
70-
71-
// Find the maximum value and its position
72-
for (int i = 0; i < h; i++)
73-
for (int j = 0; j < w; j++)
74-
{
75-
if (!(m[i, j] > max)) continue;
76-
max = m[i, j];
77-
x = i;
78-
y = j;
79-
}
80-
return (x, y, max);
81-
}
82-
8355
/// <summary>
8456
/// Finds the minimum and maximum value in the input memory area
8557
/// </summary>
@@ -115,33 +87,6 @@ internal static unsafe (float Min, float Max) MinMax(float* p, int length)
11587
return m;
11688
}
11789

118-
/// <summary>
119-
/// Flattens the input volume in a linear array
120-
/// </summary>
121-
/// <param name="volume">The volume to flatten</param>
122-
[PublicAPI]
123-
[Pure, NotNull]
124-
[CollectionAccess(CollectionAccessType.Read)]
125-
public static float[] Flatten([NotNull, ItemNotNull] this IReadOnlyList<float[,]> volume)
126-
{
127-
// Preliminary checks and declarations
128-
if (volume.Count == 0) throw new ArgumentOutOfRangeException("The input volume can't be empty");
129-
int
130-
depth = volume.Count,
131-
length = volume[0].Length,
132-
bytes = sizeof(float) * length;
133-
float[] result = new float[depth * length];
134-
135-
// Execute the copy in parallel
136-
bool loopResult = Parallel.For(0, depth, i =>
137-
{
138-
// Copy the volume data
139-
Buffer.BlockCopy(volume[i], 0, result, bytes * i, bytes);
140-
}).IsCompleted;
141-
if (!loopResult) throw new Exception("Error while runnig the parallel loop");
142-
return result;
143-
}
144-
14590
/// <summary>
14691
/// Merges the input samples into a matrix dataset
14792
/// </summary>
@@ -167,42 +112,12 @@ public static (float[,], float[,]) MergeRows([NotNull] this IReadOnlyList<(float
167112
return (x, y);
168113
}
169114

170-
/// <summary>
171-
/// Merges the rows of the input matrices into a single matrix
172-
/// </summary>
173-
/// <param name="blocks">The matrices to merge</param>
174-
[PublicAPI]
175-
[Pure, NotNull]
176-
[CollectionAccess(CollectionAccessType.Read)]
177-
public static float[,] MergeRows([NotNull, ItemNotNull] this IReadOnlyList<float[,]> blocks)
178-
{
179-
// Preliminary checks and declarations
180-
if (blocks.Count == 0) throw new ArgumentOutOfRangeException("The blocks list can't be empty");
181-
int
182-
h = blocks.Sum(b => b.GetLength(0)),
183-
w = blocks[0].GetLength(1),
184-
rowBytes = sizeof(float) * w;
185-
float[,] result = new float[h, w];
186-
187-
// Execute the copy in parallel
188-
int position = 0;
189-
for (int i = 0; i < blocks.Count; i++)
190-
{
191-
float[,] next = blocks[i];
192-
if (next.GetLength(1) != w) throw new ArgumentOutOfRangeException("The blocks must all have the same width");
193-
int rows = next.GetLength(0);
194-
Buffer.BlockCopy(next, 0, result, rowBytes * position, rowBytes * rows);
195-
position += rows;
196-
}
197-
return result;
198-
}
199-
200115
#endregion
201116

202117
#region Argmax
203118

204119
/// <summary>
205-
/// Returns the index of the maximum value in the input vector
120+
/// Returns the index of the maximum value in the input memory area
206121
/// </summary>
207122
/// <param name="p">A pointer to the buffer to read</param>
208123
/// <param name="length">The length of the buffer to consider</param>
@@ -223,10 +138,19 @@ internal static unsafe int Argmax(float* p, int length)
223138
return index;
224139
}
225140

141+
/// <summary>
142+
/// Returns the index of the maximum value in the input tensor
143+
/// </summary>
144+
/// <param name="tensor">The input <see cref="Tensor"/> to read from</param>
145+
[PublicAPI]
146+
[CollectionAccess(CollectionAccessType.Read)]
147+
public static unsafe int Argmax(in this Tensor tensor) => Argmax(tensor, tensor.Size);
148+
226149
/// <summary>
227150
/// Returns the index of the maximum value in the input vector
228151
/// </summary>
229152
/// <param name="v">The input vector to read from</param>
153+
[PublicAPI]
230154
[CollectionAccess(CollectionAccessType.Read)]
231155
public static unsafe int Argmax([NotNull] this float[] v)
232156
{
@@ -237,6 +161,7 @@ public static unsafe int Argmax([NotNull] this float[] v)
237161
/// Returns the index of the maximum value in the input matrix
238162
/// </summary>
239163
/// <param name="m">The input matrix to read from</param>
164+
[PublicAPI]
240165
[CollectionAccess(CollectionAccessType.Read)]
241166
public static unsafe int Argmax([NotNull] this float[,] m)
242167
{
@@ -295,7 +220,7 @@ internal static unsafe void Fill([NotNull] this Array array, [NotNull] Func<floa
295220
/// Returns a deep copy of the input vector
296221
/// </summary>
297222
/// <param name="v">The vector to clone</param>
298-
/// <remarks>This method avoids the boxing of the <see cref="Array.Clone"/> method, and it is faster thanks to <see cref="Buffer.MemoryCopy"/></remarks>
223+
/// <remarks>This method avoids the boxing of the <see cref="Array.Clone"/> method, and it is faster thanks to the use of the methods in the <see cref="Buffer"/> class</remarks>
299224
[Pure, NotNull]
300225
[CollectionAccess(CollectionAccessType.Read)]
301226
public static unsafe float[] BlockCopy([NotNull] this float[] v)
@@ -309,7 +234,7 @@ public static unsafe float[] BlockCopy([NotNull] this float[] v)
309234

310235
#endregion
311236

312-
#region Content check
237+
#region Content equals
313238

314239
/// <summary>
315240
/// Checks if two matrices have the same size and content
@@ -318,7 +243,7 @@ public static unsafe float[] BlockCopy([NotNull] this float[] v)
318243
/// <param name="o">The second <see cref="Tensor"/> to test</param>
319244
/// <param name="absolute">The relative comparison threshold</param>
320245
/// <param name="relative">The relative comparison threshold</param>
321-
public static unsafe bool ContentEquals(in this Tensor m, in Tensor o,float absolute = 1e-6f, float relative = 1e-6f)
246+
public static unsafe bool ContentEquals(in this Tensor m, in Tensor o, float absolute = 1e-6f, float relative = 1e-6f)
322247
{
323248
if (m.Ptr == IntPtr.Zero && o.Ptr == IntPtr.Zero) return true;
324249
if (m.Ptr == IntPtr.Zero || o.Ptr == IntPtr.Zero) return false;
@@ -374,53 +299,6 @@ public static bool ContentEquals([CanBeNull] this float[] v, [CanBeNull] float[]
374299
return true;
375300
}
376301

377-
// GetUid helper method
378-
private static unsafe int GetUid(float* p, int n)
379-
{
380-
int hash = 17;
381-
unchecked
382-
{
383-
for (int i = 0; i < n; i++)
384-
hash = hash * 23 + p[i].GetHashCode();
385-
return hash;
386-
}
387-
}
388-
389-
/// <summary>
390-
/// Calculates a unique hash code for the target row of the input matrix
391-
/// </summary>
392-
[Pure]
393-
public static unsafe int GetUid([NotNull] this float[,] m, int row)
394-
{
395-
int
396-
w = m.GetLength(1),
397-
offset = row * w;
398-
fixed (float* pm = m)
399-
return GetUid(pm + offset, w);
400-
}
401-
402-
/// <summary>
403-
/// Calculates a unique hash code for the input matrix
404-
/// </summary>
405-
/// <param name="m">The matrix to analyze</param>
406-
[Pure]
407-
public static unsafe int GetUid([NotNull] this float[,] m)
408-
{
409-
fixed (float* pm = m)
410-
return GetUid(pm, m.Length);
411-
}
412-
413-
/// <summary>
414-
/// Calculates a unique hash code for the input vector
415-
/// </summary>
416-
/// <param name="v">The vector to analyze</param>
417-
[Pure]
418-
public static unsafe int GetUid([NotNull] this float[] v)
419-
{
420-
fixed (float* pv = v)
421-
return GetUid(pv, v.Length);
422-
}
423-
424302
#endregion
425303

426304
#region String display

0 commit comments

Comments
 (0)