Skip to content

Commit 7851e01

Browse files
committed
TensorMap class added (for future use)
1 parent 127e241 commit 7851e01

File tree

1 file changed

+104
-0
lines changed

1 file changed

+104
-0
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
using JetBrains.Annotations;
2+
using NeuralNetworkNET.APIs.Interfaces;
3+
using NeuralNetworkNET.APIs.Structs;
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
8+
namespace NeuralNetworkNET.Helpers
9+
{
10+
/// <summary>
11+
/// A simple map that stores references to <see cref="Tensor"/> instances while training or using a network
12+
/// </summary>
13+
internal sealed class TensorMap : IDisposable
14+
{
15+
#region IDisposable
16+
17+
~TensorMap() => Dispose();
18+
19+
/// <inheritdoc/>
20+
void IDisposable.Dispose()
21+
{
22+
GC.SuppressFinalize(this);
23+
Dispose();
24+
}
25+
26+
// Frees the allocated tensors
27+
private void Dispose()
28+
{
29+
foreach (Tensor tensor in new[] { ActivityMap, ActivationMap, DeltaMap }.SelectMany(d => d.Values))
30+
tensor.Free();
31+
ActivityMap.Clear();
32+
ActivationMap.Clear();
33+
DeltaMap.Clear();
34+
}
35+
36+
#endregion
37+
38+
// The Z tensors
39+
[NotNull]
40+
private readonly IDictionary<INetworkLayer, Tensor> ActivityMap = new Dictionary<INetworkLayer, Tensor>();
41+
42+
// The A tensors
43+
[NotNull]
44+
private readonly IDictionary<INetworkLayer, Tensor> ActivationMap = new Dictionary<INetworkLayer, Tensor>();
45+
46+
// The dy tensors
47+
[NotNull]
48+
private readonly IDictionary<INetworkLayer, Tensor> DeltaMap = new Dictionary<INetworkLayer, Tensor>();
49+
50+
/// <summary>
51+
/// Gets or sets a <see cref="Tensor"/> for the given network layer and data type
52+
/// </summary>
53+
/// <param name="layer">The source <see cref="INetworkLayer"/> instance for the target <see cref="Tensor"/></param>
54+
/// <param name="type">The <see cref="TensorType"/> value for the target <see cref="Tensor"/></param>
55+
public Tensor this[INetworkLayer layer, TensorType type]
56+
{
57+
[Pure]
58+
get
59+
{
60+
switch (type)
61+
{
62+
case TensorType.Activity: return ActivityMap[layer];
63+
case TensorType.Activation: return ActivationMap[layer];
64+
case TensorType.Delta: return DeltaMap[layer];
65+
default: throw new ArgumentOutOfRangeException(nameof(type), "Invalid data type requested");
66+
}
67+
}
68+
set
69+
{
70+
IDictionary<INetworkLayer, Tensor> target;
71+
switch (type)
72+
{
73+
case TensorType.Activity: target = ActivityMap; break;
74+
case TensorType.Activation: target = ActivationMap; break;
75+
case TensorType.Delta: target = DeltaMap; break;
76+
default: throw new ArgumentOutOfRangeException(nameof(type), "Invalid data type requested");
77+
}
78+
if (target.TryGetValue(layer, out Tensor old)) old.Free();
79+
target[layer] = value;
80+
}
81+
}
82+
}
83+
84+
/// <summary>
85+
/// Indicates the type of any given <see cref="Tensor"/>
86+
/// </summary>
87+
internal enum TensorType
88+
{
89+
/// <summary>
90+
/// The activity of a network layer, the output before the activation function
91+
/// </summary>
92+
Activity,
93+
94+
/// <summary>
95+
/// The activation of a network layer, the output with the activation function applied to it
96+
/// </summary>
97+
Activation,
98+
99+
/// <summary>
100+
/// The error delta for the outputs of a given network layer
101+
/// </summary>
102+
Delta
103+
}
104+
}

0 commit comments

Comments
 (0)