Skip to content

Commit 495356a

Browse files
committed
Deserialization code refactoring
1 parent 7331499 commit 495356a

File tree

6 files changed

+42
-73
lines changed

6 files changed

+42
-73
lines changed

NeuralNetwork.NET/APIs/CuDnnNetworkLayersDeserializer.cs

Lines changed: 0 additions & 39 deletions
This file was deleted.

NeuralNetwork.NET/APIs/Delegates/LayerDeserializer.cs

Lines changed: 0 additions & 15 deletions
This file was deleted.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
namespace NeuralNetworkNET.APIs.Enums
2+
{
3+
/// <summary>
4+
/// Indicates the preferred type of network layers to serialize, whenever possible
5+
/// </summary>
6+
public enum LayersLoadingPreference
7+
{
8+
Cpu,
9+
Cuda
10+
}
11+
}

NeuralNetwork.NET/APIs/NeuralNetworkLoader.cs

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
using System.Collections.Generic;
33
using System.IO;
44
using System.IO.Compression;
5-
using System.Linq;
65
using JetBrains.Annotations;
7-
using NeuralNetworkNET.APIs.Delegates;
86
using NeuralNetworkNET.APIs.Interfaces;
97
using NeuralNetworkNET.APIs.Enums;
8+
using NeuralNetworkNET.Cuda.Layers;
109
using NeuralNetworkNET.Extensions;
1110
using NeuralNetworkNET.Networks.Implementations;
1211
using NeuralNetworkNET.Networks.Implementations.Layers;
@@ -27,43 +26,37 @@ public static class NeuralNetworkLoader
2726
/// Tries to deserialize a network from the input file
2827
/// </summary>
2928
/// <param name="file">The <see cref="FileInfo"/> instance for the file to load</param>
29+
/// <param name="preference">The layers deserialization preference</param>
3030
/// <returns>The deserialized network, or null if the operation fails</returns>
3131
[PublicAPI]
3232
[Pure, CanBeNull]
33-
public static INeuralNetwork TryLoad([NotNull] FileInfo file)
33+
public static INeuralNetwork TryLoad([NotNull] FileInfo file, LayersLoadingPreference preference)
3434
{
3535
using (FileStream stream = file.OpenRead())
36-
return TryLoad(stream);
36+
return TryLoad(stream, preference);
3737
}
3838

3939
/// <summary>
4040
/// Tries to deserialize a network from the input <see cref="Stream"/>
4141
/// </summary>
4242
/// <param name="stream">The <see cref="Stream"/> instance for the network to load</param>
43-
/// <param name="deserializers">The list of deserializers to use to load the input network</param>
43+
/// <param name="preference">The layers deserialization preference</param>
4444
/// <returns>The deserialized network, or null if the operation fails</returns>
4545
[PublicAPI]
4646
[Pure, CanBeNull]
47-
public static INeuralNetwork TryLoad([NotNull] Stream stream, params LayerDeserializer[] deserializers)
47+
public static INeuralNetwork TryLoad([NotNull] Stream stream, LayersLoadingPreference preference)
4848
{
49-
if (deserializers.GroupBy(f => f).Any(g => g.Count() > 1)) throw new ArgumentException("The deserializers list can't contain duplicate entries", nameof(deserializers));
5049
try
5150
{
5251
List<INetworkLayer> layers = new List<INetworkLayer>();
5352
using (GZipStream gzip = new GZipStream(stream, CompressionMode.Decompress))
5453
{
5554
while (gzip.TryRead(out LayerType type))
5655
{
57-
// Process the deserializers in precedence order
56+
// Deserialization attempt
5857
INetworkLayer layer = null;
59-
foreach (LayerDeserializer deserializer in deserializers)
60-
{
61-
layer = deserializer(gzip, type);
62-
if (layer != null) break;
63-
}
64-
65-
// Fallback
66-
if (layer == null) layer = DefaultLayersDeserializer(gzip, type);
58+
if (preference == LayersLoadingPreference.Cuda) layer = CudaDeserialize(gzip, type);
59+
if (layer == null) layer = CpuDeserialize(gzip, type);
6760
if (layer == null) return null;
6861

6962
// Add to the queue
@@ -81,9 +74,11 @@ public static INeuralNetwork TryLoad([NotNull] Stream stream, params LayerDeseri
8174
}
8275
}
8376

77+
#region Deserializers
78+
8479
// Default layers deserializer
8580
[MustUseReturnValue, CanBeNull]
86-
private static INetworkLayer DefaultLayersDeserializer([NotNull] Stream stream, LayerType type)
81+
private static INetworkLayer CpuDeserialize([NotNull] Stream stream, LayerType type)
8782
{
8883
switch (type)
8984
{
@@ -95,5 +90,22 @@ private static INetworkLayer DefaultLayersDeserializer([NotNull] Stream stream,
9590
default: throw new ArgumentOutOfRangeException(nameof(type), $"The {type} layer type is not supported by the default deserializer");
9691
}
9792
}
93+
94+
// Cuda layers deserializer
95+
[MustUseReturnValue, CanBeNull]
96+
private static INetworkLayer CudaDeserialize([NotNull] Stream stream, LayerType type)
97+
{
98+
switch (type)
99+
{
100+
case LayerType.FullyConnected: return CuDnnFullyConnectedLayer.Deserialize(stream);
101+
case LayerType.Convolutional: return CuDnnConvolutionalLayer.Deserialize(stream);
102+
case LayerType.Pooling: return CuDnnPoolingLayer.Deserialize(stream);
103+
case LayerType.Softmax: return CuDnnSoftmaxLayer.Deserialize(stream);
104+
case LayerType.Inception: return CuDnnInceptionLayer.Deserialize(stream);
105+
default: return null;
106+
}
107+
}
108+
109+
#endregion
98110
}
99111
}

Unit/NeuralNetwork.NET.Cuda.Unit/SerializationTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public void NetworkSerialization()
3232
{
3333
network.Save(stream);
3434
stream.Seek(0, SeekOrigin.Begin);
35-
INeuralNetwork copy = NeuralNetworkLoader.TryLoad(stream, CuDnnNetworkLayersDeserializer.Deserializer);
35+
INeuralNetwork copy = NeuralNetworkLoader.TryLoad(stream, LayersLoadingPreference.Cuda);
3636
Assert.IsTrue(network.Equals(copy));
3737
}
3838
}

Unit/NeuralNetwork.NET.Unit/SerializationTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ public void NetworkSerialization()
7373
{
7474
network.Save(stream);
7575
stream.Seek(0, SeekOrigin.Begin);
76-
INeuralNetwork copy = NeuralNetworkLoader.TryLoad(stream);
76+
INeuralNetwork copy = NeuralNetworkLoader.TryLoad(stream, LayersLoadingPreference.Cpu);
7777
Assert.IsTrue(network.Equals(copy));
7878
}
7979
}

0 commit comments

Comments
 (0)