@@ -16,6 +16,8 @@ namespace NeuralNetworkNET.APIs.Structs
16
16
[ DebuggerDisplay ( "Entities: {Entities}, Length: {Length}, Ptr: {Ptr}" ) ]
17
17
public readonly struct Tensor
18
18
{
19
+ #region Fields and parameters
20
+
19
21
/// <summary>
20
22
/// The <see cref="IntPtr"/> value to the allocated memory
21
23
/// </summary>
@@ -51,6 +53,8 @@ public bool IsNull
51
53
get => Ptr == IntPtr . Zero ;
52
54
}
53
55
56
+ #endregion
57
+
54
58
/// <summary>
55
59
/// Gets a null instance
56
60
/// </summary>
@@ -109,6 +113,20 @@ public static unsafe void Reshape(float* p, int n, int chw, out Tensor tensor)
109
113
tensor = new Tensor ( ptr , n , chw ) ;
110
114
}
111
115
116
+ /// <summary>
117
+ /// Creates a new instance with the same shape as the input <see cref="Tensor"/>
118
+ /// </summary>
119
+ /// <param name="mask">The <see cref="Tensor"/> to use to copy the shape</param>
120
+ /// <param name="tensor">The output <see cref="Tensor"/></param>
121
+ public static void Like ( in Tensor mask , out Tensor tensor ) => New ( mask . Entities , mask . Length , out tensor ) ;
122
+
123
+ /// <summary>
124
+ /// Creates a new instance with the same shape as the input <see cref="Tensor"/> and all the values initializes to 0
125
+ /// </summary>
126
+ /// <param name="mask">The <see cref="Tensor"/> to use to copy the shape</param>
127
+ /// <param name="tensor">The output <see cref="Tensor"/></param>
128
+ public static void LikeZeroed ( in Tensor mask , out Tensor tensor ) => NewZeroed ( mask . Entities , mask . Length , out tensor ) ;
129
+
112
130
/// <summary>
113
131
/// Creates a new instance by copying the contents at the given memory location and reshaping it to the desired size
114
132
/// </summary>
@@ -155,6 +173,34 @@ public static unsafe void From([NotNull] float[] v, int n, int chw, out Tensor t
155
173
156
174
#region Tools
157
175
176
+ /// <summary>
177
+ /// Creates a new instance by wrapping the current memory area
178
+ /// </summary>
179
+ /// <param name="n">The height of the final matrix</param>
180
+ /// <param name="chw">The width of the final matrix</param>
181
+ /// <param name="tensor">The resulting instance</param>
182
+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
183
+ public void Reshape ( int n , int chw , out Tensor tensor )
184
+ {
185
+ if ( n * chw != Size ) throw new ArgumentException ( "Invalid input resized shape" ) ;
186
+ tensor = new Tensor ( Ptr , n , chw ) ;
187
+ }
188
+
189
+ /// <summary>
190
+ /// Checks whether or not the current instance has the same shape of the input <see cref="Tensor"/>
191
+ /// </summary>
192
+ /// <param name="tensor">The instance to compare</param>
193
+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
194
+ public bool MatchShape ( in Tensor tensor ) => Entities == tensor . Entities && Length == tensor . Length ;
195
+
196
+ /// <summary>
197
+ /// Checks whether or not the current instance has the same shape as the input arguments
198
+ /// </summary>
199
+ /// <param name="entities">The expected number of entities</param>
200
+ /// <param name="length">The expected length of each entity</param>
201
+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
202
+ public bool MatchShape ( int entities , int length ) => Entities == entities && Length == length ;
203
+
158
204
/// <summary>
159
205
/// Overwrites the contents of the current matrix with the input matrix
160
206
/// </summary>
@@ -181,43 +227,36 @@ internal void Duplicate(out Tensor tensor)
181
227
/// <summary>
182
228
/// Copies the contents of the unmanaged array to a managed <see cref="Array"/>
183
229
/// </summary>
230
+ /// <param name="keepAlive">Indicates whether or not to automatically dispose the current instance</param>
184
231
[ Pure , NotNull ]
185
- public float [ ] ToArray ( )
232
+ public float [ ] ToArray ( bool keepAlive = true )
186
233
{
187
234
if ( Ptr == IntPtr . Zero ) return new float [ 0 ] ;
188
235
float [ ] result = new float [ Size ] ;
189
236
Marshal . Copy ( Ptr , result , 0 , Size ) ;
237
+ if ( ! keepAlive ) Free ( ) ;
190
238
return result ;
191
239
}
192
240
193
241
/// <summary>
194
242
/// Copies the contents of the unmanaged array to a managed 2D <see cref="Array"/>
195
243
/// </summary>
244
+ /// <param name="keepAlive">Indicates whether or not to automatically dispose the current instance</param>
196
245
[ Pure , NotNull ]
197
- public unsafe float [ , ] ToArray2D ( )
246
+ public unsafe float [ , ] ToArray2D ( bool keepAlive = true )
198
247
{
199
248
if ( Ptr == IntPtr . Zero ) return new float [ 0 , 0 ] ;
200
249
float [ , ] result = new float [ Entities , Length ] ;
201
250
int size = sizeof ( float ) * Size ;
202
251
fixed ( float * presult = result )
203
252
Buffer . MemoryCopy ( this , presult , size , size ) ;
253
+ if ( ! keepAlive ) Free ( ) ;
204
254
return result ;
205
255
}
206
256
207
257
#endregion
208
258
209
- /// <summary>
210
- /// Creates a new instance by wrapping the current memory area
211
- /// </summary>
212
- /// <param name="n">The height of the final matrix</param>
213
- /// <param name="chw">The width of the final matrix</param>
214
- /// <param name="tensor">The resulting instance</param>
215
- [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
216
- public void Reshape ( int n , int chw , out Tensor tensor )
217
- {
218
- if ( n * chw != Size ) throw new ArgumentException ( "Invalid input resized shape" ) ;
219
- tensor = new Tensor ( Ptr , n , chw ) ;
220
- }
259
+ #region Memory management
221
260
222
261
/// <summary>
223
262
/// Frees the memory associated with the current instance
@@ -237,8 +276,11 @@ public void TryFree()
237
276
238
277
// Implicit pointer conversion
239
278
[ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
279
+ [ SuppressMessage ( "ReSharper" , "ImpureMethodCallOnReadonlyValueField" ) ]
240
280
public static unsafe implicit operator float * ( in Tensor tensor ) => ( float * ) tensor . Ptr . ToPointer ( ) ;
241
281
282
+ #endregion
283
+
242
284
#region Debug
243
285
244
286
/// <summary>
0 commit comments