1
+ import numpy as np
2
+ import unittest
3
+ import coremltools .models .datatypes as datatypes
4
+ from coremltools .models import neural_network as neural_network
5
+ from coremltools .converters .nnssa .coreml .graph_pass .mlmodel_passes import remove_disconnected_constants
6
+
7
+
8
+ class MLModelPassesTest (unittest .TestCase ):
9
+
10
+ def test_load_constant_remove (self ):
11
+ input_features = [('data' , datatypes .Array (* (3 , 4 )))]
12
+ output_features = [('out' , None )]
13
+ builder = neural_network .NeuralNetworkBuilder (input_features , output_features , disable_rank5_shape_mapping = True )
14
+ builder .add_activation ('relu1' , 'RELU' , 'data' , 'relu1' )
15
+ builder .add_load_constant_nd ('const1' , 'c1' , constant_value = np .ones ((5 ,)), shape = (5 ,))
16
+ builder .add_activation ('relu2' , 'RELU' , 'relu1' , 'out' )
17
+ builder .add_load_constant_nd ('const2' , 'c2' , constant_value = np .ones ((5 ,)), shape = (5 ,))
18
+ builder .add_load_constant_nd ('const3' , 'c3' , constant_value = np .ones ((5 ,)), shape = (5 ,))
19
+ spec = builder .spec
20
+ np .testing .assert_equal (5 , len (spec .neuralNetwork .layers ))
21
+ remove_disconnected_constants (spec )
22
+ np .testing .assert_equal (2 , len (spec .neuralNetwork .layers ))
23
+
24
+
25
+ if __name__ == '__main__' :
26
+ RUN_ALL_TESTS = True
27
+ if RUN_ALL_TESTS :
28
+ unittest .main ()
29
+ else :
30
+ suite = unittest .TestSuite ()
31
+ suite .addTest (MLModelPassesTest ('test_load_constant_remove' ))
32
+ unittest .TextTestRunner ().run (suite )
0 commit comments