Skip to content
This repository was archived by the owner on Feb 7, 2023. It is now read-only.

Commit 67491e8

Browse files
committed
Adding ML Model passes
1 parent 4598cd1 commit 67491e8

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

onnx_coreml/converter.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@
2828
ReshapeInitTensorFuser, BNBroadcastedMulFuser, BNBroadcastedAddFuser, \
2929
PixelShuffleFuser, OutputRenamer, AddModelInputsOutputs, \
3030
ConstantsToInitializers, ImageScalerRemover, ShapeOpRemover, ConstantRemover, \
31-
ConstantFillToInitializers, ReshapeTransposeReshape_pattern1, CastOpRemover, DeadCodeElimination
31+
ConstantFillToInitializers, ReshapeTransposeReshape_pattern1, CastOpRemover, \
32+
DeadCodeElimination
33+
34+
# ML model passes
35+
from coremltools.converters.nnssa.coreml.graph_pass.mlmodel_passes import remove_disconnected_constants
3236

3337
from ._error_utils import ErrorHandling
3438
from .graph_viz import plot_graph # type: ignore
@@ -717,6 +721,11 @@ def _add_informative_description(feature, raise_error=True):
717721
if layer.WhichOneof('layer') == 'resizeBilinear' or layer.WhichOneof('layer') == 'cropResize':
718722
raise TypeError('{} not supported with target iOS 11.2 please provide higher target iOS'.format(layer.WhichOneof('layer')))
719723

724+
# Optimize ML Model Spec
725+
ml_model_passes = [remove_disconnected_constants]
726+
for opt in ml_model_passes:
727+
opt(builder.spec)
728+
720729
print("Translation to CoreML spec completed. Now compiling the CoreML model.")
721730
try:
722731
if DEBUG:

tests/test_mlmodel_passes.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import numpy as np
2+
import unittest
3+
import coremltools.models.datatypes as datatypes
4+
from coremltools.models import neural_network as neural_network
5+
from coremltools.converters.nnssa.coreml.graph_pass.mlmodel_passes import remove_disconnected_constants
6+
7+
8+
class MLModelPassesTest(unittest.TestCase):
9+
10+
def test_load_constant_remove(self):
11+
input_features = [('data', datatypes.Array(*(3, 4)))]
12+
output_features = [('out', None)]
13+
builder = neural_network.NeuralNetworkBuilder(input_features, output_features, disable_rank5_shape_mapping=True)
14+
builder.add_activation('relu1', 'RELU', 'data', 'relu1')
15+
builder.add_load_constant_nd('const1', 'c1', constant_value=np.ones((5,)), shape=(5,))
16+
builder.add_activation('relu2', 'RELU', 'relu1', 'out')
17+
builder.add_load_constant_nd('const2', 'c2', constant_value=np.ones((5,)), shape=(5,))
18+
builder.add_load_constant_nd('const3', 'c3', constant_value=np.ones((5,)), shape=(5,))
19+
spec = builder.spec
20+
np.testing.assert_equal(5, len(spec.neuralNetwork.layers))
21+
remove_disconnected_constants(spec)
22+
np.testing.assert_equal(2, len(spec.neuralNetwork.layers))
23+
24+
25+
if __name__ == '__main__':
26+
RUN_ALL_TESTS = True
27+
if RUN_ALL_TESTS:
28+
unittest.main()
29+
else:
30+
suite = unittest.TestSuite()
31+
suite.addTest(MLModelPassesTest('test_load_constant_remove'))
32+
unittest.TextTestRunner().run(suite)

0 commit comments

Comments
 (0)