This repository was archived by the owner on Feb 7, 2023. It is now read-only.

Commit 0d5c37e

Merge pull request #339 from aseemw/dev/add_const
Added squeeze, unsqueeze, const conversions
2 parents bd65864 + 0ac96db commit 0d5c37e

5 files changed: +79 -18 lines changed


onnx_coreml/_graph.py

Lines changed: 2 additions & 0 deletions
@@ -139,6 +139,8 @@ def __init__(self,
         # data blob name to the op_type that generates it
         self.blob_from_op_type = {} # type: Dict[Text, Text]
 
+        self.constant_layers_added = {} # type: Dict[Text, bool]
+
         for node_ in nodes:
             for input_ in node_.inputs:
                 if input_ in self.blob_to_op_type:
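
For orientation only (not part of the commit): the new constant_layers_added dictionary is used by _convert_const in _operators.py as a seen-set, so a constant initializer that feeds several nodes is emitted as a load_constant layer only once. A minimal standalone sketch of that bookkeeping, with hypothetical names:

constant_layers_added = {}  # type: dict

def emit_constant_once(name):
    # Stand-in for the builder.add_load_constant call made in _convert_const.
    if name not in constant_layers_added:
        constant_layers_added[name] = True
        return True   # layer emitted
    return False      # already emitted, skipped

assert emit_constant_once("W") is True
assert emit_constant_once("W") is False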

onnx_coreml/_operators.py

Lines changed: 39 additions & 13 deletions
@@ -11,8 +11,6 @@
 from coremltools.proto import NeuralNetwork_pb2 #type: ignore
 from ._error_utils import ErrorHandling
 
-_SEQUENCE_LAYERS_REGISTRY = set(["LSTM"])
-
 def _compare(a, b, encoding="utf8"): #type: (Text, Text, Text) -> bool
     if isinstance(a, bytes):
         a = a.decode(encoding)
@@ -348,9 +346,6 @@ def _convert_add(builder, node, graph, err):
                          shape_bias=[second_input.shape[0]])
         return
 
-    if 'broadcast' in node.attrs:
-        if node.attrs['broadcast'] == 1:
-            return err.unsupported_op_configuration(builder, node, graph, "Broadcast Add is not supported now")
     builder.add_elementwise(
         name=node.name,
         input_names=node.inputs,
@@ -359,10 +354,6 @@ def _convert_add(builder, node, graph, err):
     )
 
 def _convert_mul(builder, node, graph, err): # type: (NeuralNetworkBuilder, Node, Graph, ErrorHandling) -> None
-    if 'broadcast' in node.attrs:
-        if node.attrs['broadcast'] == 1:
-            return err.unsupported_op_configuration(builder, node, graph, "Broadcast Multiply is not supported now")
-
     builder.add_elementwise(
         name=node.name,
         input_names=node.inputs,
@@ -371,10 +362,6 @@ def _convert_mul(builder, node, graph, err):
     )
 
 def _convert_div(builder, node, graph, err): # type: (NeuralNetworkBuilder, Node, Graph, ErrorHandling) -> None
-    if 'broadcast' in node.attrs:
-        if node.attrs['broadcast'] == 1:
-            return err.unsupported_op_configuration(builder, node, graph, "Broadcast Div is not supported now")
-
     builder.add_unary(name=node.name + '_inverse', #type: ignore
                       input_name=node.inputs[1],
                       output_name=node.inputs[1] + '_inverse',
@@ -985,6 +972,34 @@ def _convert_custom(builder, node, graph, err):
 
     err.custom_layer_nodes.append(node)
 
+def _convert_identity(builder, node, graph, err): # type: (NeuralNetworkBuilder, Node, Graph, ErrorHandling) -> None
+    builder.add_activation(
+        name=node.name,
+        non_linearity = 'LINEAR',
+        input_name=node.inputs[0],
+        output_name=node.outputs[0],
+        params=[1.0, 0.0]
+    )
+
+def _convert_const(builder, node, graph, err): # type: (NeuralNetworkBuilder, Node, Graph, ErrorHandling) -> None
+
+    for name, value in node.input_tensors.items():
+        if name not in graph.constant_layers_added:
+            shape = value.shape
+            coreml_shape = [1,1,1]
+            if len(shape) == 3:
+                coreml_shape = list(shape)
+            elif len(shape) == 1:
+                coreml_shape = [shape[0],1,1]
+            elif len(shape) == 2:
+                coreml_shape = [1, shape[0], shape[1]]
+            else:
+                return err.unsupported_op_configuration(builder, node, graph, "unable to translate constant array shape to CoreML shape")
+            builder.add_load_constant(name=name,
+                                      output_name=name,
+                                      constant_value=value.flatten(),
+                                      shape=coreml_shape)
+            graph.constant_layers_added[name] = True
 
 
 _ONNX_NODE_REGISTRY = {
@@ -1050,8 +1065,13 @@ def _convert_custom(builder, node, graph, err):
     "ArgMin": _convert_reduce,
     "Clip": _convert_clip,
     "MeanVarianceNormalization": _convert_mvn,
+    "Unsqueeze": _convert_identity,
+    "Squeeze": _convert_identity
 }
 
+_SEQUENCE_LAYERS_REGISTRY = set(["LSTM"])
+
+_CONST_INPUT_ALLOWED_LAYERS = set([ "Add", "Sum", "Mul", "Concat", "Max", "Min", "Div", "Reciprocal"])
 
 def _get_node_converter_fn(builder, node, err): # type: (NeuralNetworkBuilder, Node, ErrorHandling) -> Callable[[NeuralNetworkBuilder, Node, Graph, ErrorHandling], None]
     """
@@ -1063,6 +1083,12 @@ def _get_node_converter_fn(builder, node, err):
     else:
         return err.unsupported_op(node)
 
+def _add_const_inputs_if_required(builder, node, graph, err): # type: (NeuralNetworkBuilder, Node, Graph, ErrorHandling) -> None
+    if node.op_type in _CONST_INPUT_ALLOWED_LAYERS:
+        if len(node.input_tensors) > 0:
+            _convert_const(builder, node, graph, err)
+
+
 def _convert_node(builder, node, graph, err): # type: (NeuralNetworkBuilder, Node, Graph, ErrorHandling) -> None
     converter_fn = _get_node_converter_fn(builder, node, err)
     return converter_fn(builder, node, graph, err)
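
As an illustration (not part of the commit), the rank handling inside _convert_const maps a constant tensor's shape to the rank-3 [C, H, W] shape that add_load_constant expects; a self-contained sketch with hypothetical inputs:

def to_coreml_shape(shape):
    # Mirrors the branches in _convert_const: 3-D passes through, 1-D maps to
    # the channel axis, 2-D maps to (H, W); higher ranks are rejected.
    if len(shape) == 3:
        return list(shape)
    elif len(shape) == 1:
        return [shape[0], 1, 1]
    elif len(shape) == 2:
        return [1, shape[0], shape[1]]
    return None  # _convert_const reports unsupported_op_configuration in this case

print(to_coreml_shape((4,)))       # [4, 1, 1]
print(to_coreml_shape((2, 3)))     # [1, 2, 3]
print(to_coreml_shape((2, 3, 4)))  # [2, 3, 4]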

onnx_coreml/converter.py

Lines changed: 8 additions & 2 deletions
@@ -16,7 +16,7 @@
 
 from typing import Tuple
 
-from ._operators import _convert_node, _SEQUENCE_LAYERS_REGISTRY, _ONNX_NODE_REGISTRY
+from ._operators import _convert_node, _SEQUENCE_LAYERS_REGISTRY, _ONNX_NODE_REGISTRY, _add_const_inputs_if_required
 from ._graph import Graph, EdgeInfo, Transformer
 from ._transformers import ConvAddFuser, DropoutRemover, \
     ReshapeInitTensorFuser, BNBroadcastedMulFuser, BNBroadcastedAddFuser, \
@@ -410,6 +410,7 @@ def convert(model, # type: Union[onnx.ModelProto, Text]
 
     for i, node in enumerate(graph.nodes):
         print("%d/%d: Converting Node Type %s" %(i+1, len(graph.nodes), node.op_type))
+        _add_const_inputs_if_required(builder, node, graph, err)
        _convert_node(builder, node, graph, err)
 
     if add_deprocess:
@@ -460,7 +461,12 @@ def convert(model, # type: Union[onnx.ModelProto, Text]
             if outputs.name == output_:
                 builder.spec.description.output[i].shortDescription = 'This output is a sequence'
 
-    mlmodel = MLModel(builder.spec)
+    print("Translation to CoreML spec completed. Now compiling the CoreML model.")
+    try:
+        mlmodel = MLModel(builder.spec)
+    except:
+        raise ValueError('Compilation failed. Translation to CoreML spec was incorrect.')
+
 
     # print information about all ops for which custom layers have been added
     if len(err.custom_layer_nodes) > 0:
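
A hedged usage note: with the new try/except, a failed spec compilation surfaces from convert() as a ValueError rather than an unwrapped coremltools exception. A caller-side sketch (the model path here is hypothetical):

from onnx_coreml import convert

try:
    mlmodel = convert("model.onnx")  # prints per-node progress, then compiles the spec
except ValueError as e:
    # Raised when MLModel(builder.spec) fails after translation completes.
    print("CoreML compilation failed:", e)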

tests/model_test.py

Lines changed: 30 additions & 1 deletion
@@ -41,7 +41,7 @@ def _test_torch_model_single_io(torch_model, torch_input_shape, coreml_input_shape
 
     # delete onnx model
     if os.path.exists(model_dir):
-       shutil.rmtree(model_dir)
+        shutil.rmtree(model_dir)
 
 class OnnxModelTest(unittest.TestCase):
 
@@ -77,6 +77,35 @@ def forward(self, x):
         torch_model.train(False)
         _test_torch_model_single_io(torch_model, (1, 3, 100, 100), (3, 100, 100)) # type: ignore
 
+    def test_const_initializer1(self): # type: () -> None
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.ones = torch.nn.Parameter(torch.ones(1,))
+
+            def forward(self, x):
+                y = x + self.ones
+                return y
+
+        torch_model = Net() # type: ignore
+        torch_model.train(False)
+        _test_torch_model_single_io(torch_model, (1, 3), (3,)) # type: ignore
+
+
+    def test_const_initializer2(self): # type: () -> None
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+
+            def forward(self, x):
+                y = x + torch.nn.Parameter(torch.ones(2, 3))
+                return y
+
+        torch_model = Net() # type: ignore
+        torch_model.train(False)
+        _test_torch_model_single_io(torch_model, (1, 2, 3), (1, 2, 3)) # type: ignore
+
+
 
 if __name__ == '__main__':
     unittest.main()
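
For context, a hedged end-to-end sketch mirroring test_const_initializer1 (the file name and the explicit export call are illustrative, not part of the commit): the PyTorch parameter becomes an ONNX initializer feeding an Add node, which the converter now emits as a load_constant layer.

import torch
import torch.nn as nn
from onnx_coreml import convert

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ones = torch.nn.Parameter(torch.ones(1,))

    def forward(self, x):
        return x + self.ones

model = Net()
model.train(False)
torch.onnx.export(model, torch.rand(1, 3), "const_add.onnx")
mlmodel = convert("const_add.onnx")  # the Add's constant input is loaded via a constant layer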

tests/onnx_backend_node_test.py

Lines changed: 0 additions & 2 deletions
@@ -127,11 +127,9 @@ def run_node(cls,
 backend_test.exclude('test_log_softmax_lastdim_cpu')
 backend_test.exclude('test_softmax_functional_dim3_cpu')
 backend_test.exclude('test_softmax_lastdim_cpu')
-backend_test.exclude('test_squeeze_cpu')
 backend_test.exclude('test_sub_bcast_cpu')
 backend_test.exclude('test_sub_cpu')
 backend_test.exclude('test_sub_example_cpu')
-backend_test.exclude('test_unsqueeze_cpu')
 backend_test.exclude('test_slice_end_out_of_bounds_cpu')
 backend_test.exclude('test_slice_neg_cpu')
 backend_test.exclude('test_GLU_cpu')
