2
2
using System . Collections . Generic ;
3
3
using System . IO ;
4
4
using System . IO . Compression ;
5
- using System . Linq ;
6
5
using JetBrains . Annotations ;
7
- using NeuralNetworkNET . APIs . Delegates ;
8
6
using NeuralNetworkNET . APIs . Interfaces ;
9
7
using NeuralNetworkNET . APIs . Enums ;
8
+ using NeuralNetworkNET . Cuda . Layers ;
10
9
using NeuralNetworkNET . Extensions ;
11
10
using NeuralNetworkNET . Networks . Implementations ;
12
11
using NeuralNetworkNET . Networks . Implementations . Layers ;
@@ -27,43 +26,37 @@ public static class NeuralNetworkLoader
27
26
/// Tries to deserialize a network from the input file
28
27
/// </summary>
29
28
/// <param name="file">The <see cref="FileInfo"/> instance for the file to load</param>
29
+ /// <param name="preference">The layers deserialization preference</param>
30
30
/// <returns>The deserialized network, or null if the operation fails</returns>
31
31
[ PublicAPI ]
32
32
[ Pure , CanBeNull ]
33
- public static INeuralNetwork TryLoad ( [ NotNull ] FileInfo file )
33
+ public static INeuralNetwork TryLoad ( [ NotNull ] FileInfo file , LayersLoadingPreference preference )
34
34
{
35
35
using ( FileStream stream = file . OpenRead ( ) )
36
- return TryLoad ( stream ) ;
36
+ return TryLoad ( stream , preference ) ;
37
37
}
38
38
39
39
/// <summary>
40
40
/// Tries to deserialize a network from the input <see cref="Stream"/>
41
41
/// </summary>
42
42
/// <param name="stream">The <see cref="Stream"/> instance for the network to load</param>
43
- /// <param name="deserializers ">The list of deserializers to use to load the input network </param>
43
+ /// <param name="preference ">The layers deserialization preference </param>
44
44
/// <returns>The deserialized network, or null if the operation fails</returns>
45
45
[ PublicAPI ]
46
46
[ Pure , CanBeNull ]
47
- public static INeuralNetwork TryLoad ( [ NotNull ] Stream stream , params LayerDeserializer [ ] deserializers )
47
+ public static INeuralNetwork TryLoad ( [ NotNull ] Stream stream , LayersLoadingPreference preference )
48
48
{
49
- if ( deserializers . GroupBy ( f => f ) . Any ( g => g . Count ( ) > 1 ) ) throw new ArgumentException ( "The deserializers list can't contain duplicate entries" , nameof ( deserializers ) ) ;
50
49
try
51
50
{
52
51
List < INetworkLayer > layers = new List < INetworkLayer > ( ) ;
53
52
using ( GZipStream gzip = new GZipStream ( stream , CompressionMode . Decompress ) )
54
53
{
55
54
while ( gzip . TryRead ( out LayerType type ) )
56
55
{
57
- // Process the deserializers in precedence order
56
+ // Deserialization attempt
58
57
INetworkLayer layer = null ;
59
- foreach ( LayerDeserializer deserializer in deserializers )
60
- {
61
- layer = deserializer ( gzip , type ) ;
62
- if ( layer != null ) break ;
63
- }
64
-
65
- // Fallback
66
- if ( layer == null ) layer = DefaultLayersDeserializer ( gzip , type ) ;
58
+ if ( preference == LayersLoadingPreference . Cuda ) layer = CudaDeserialize ( gzip , type ) ;
59
+ if ( layer == null ) layer = CpuDeserialize ( gzip , type ) ;
67
60
if ( layer == null ) return null ;
68
61
69
62
// Add to the queue
@@ -81,9 +74,11 @@ public static INeuralNetwork TryLoad([NotNull] Stream stream, params LayerDeseri
81
74
}
82
75
}
83
76
77
+ #region Deserializers
78
+
84
79
// Default layers deserializer
85
80
[ MustUseReturnValue , CanBeNull ]
86
- private static INetworkLayer DefaultLayersDeserializer ( [ NotNull ] Stream stream , LayerType type )
81
+ private static INetworkLayer CpuDeserialize ( [ NotNull ] Stream stream , LayerType type )
87
82
{
88
83
switch ( type )
89
84
{
@@ -95,5 +90,22 @@ private static INetworkLayer DefaultLayersDeserializer([NotNull] Stream stream,
95
90
default : throw new ArgumentOutOfRangeException ( nameof ( type ) , $ "The { type } layer type is not supported by the default deserializer") ;
96
91
}
97
92
}
93
+
94
+ // Cuda layers deserializer
95
+ [ MustUseReturnValue , CanBeNull ]
96
+ private static INetworkLayer CudaDeserialize ( [ NotNull ] Stream stream , LayerType type )
97
+ {
98
+ switch ( type )
99
+ {
100
+ case LayerType . FullyConnected : return CuDnnFullyConnectedLayer . Deserialize ( stream ) ;
101
+ case LayerType . Convolutional : return CuDnnConvolutionalLayer . Deserialize ( stream ) ;
102
+ case LayerType . Pooling : return CuDnnPoolingLayer . Deserialize ( stream ) ;
103
+ case LayerType . Softmax : return CuDnnSoftmaxLayer . Deserialize ( stream ) ;
104
+ case LayerType . Inception : return CuDnnInceptionLayer . Deserialize ( stream ) ;
105
+ default : return null ;
106
+ }
107
+ }
108
+
109
+ #endregion
98
110
}
99
111
}
0 commit comments