Skip to content

Commit 16eb4b4

Browse files
authored
Add a new api from_tflite to improve user experience. (#1954)
* Add a new api from_tflite to improve user experience. Signed-off-by: Jay Zhang <jiz@microsoft.com>
1 parent 29b76df commit 16eb4b4

File tree

4 files changed

+99
-0
lines changed

4 files changed

+99
-0
lines changed

README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,36 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_graph_def(graph_def,
377377
An ONNX model_proto and an external_tensor_storage dict.
378378
```
379379

380+
### from_tflite
381+
```
382+
import tf2onnx
383+
384+
model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(tflite_path,
385+
input_names=None, output_names=None, opset=None, custom_ops=None, custom_op_handlers=None,
386+
custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, target=None,
387+
large_model=False, output_path=None):
388+
389+
Args:
390+
tflite_path: the tflite model file full path
391+
input_names: list of input names
392+
output_names: list of output names
393+
opset: the opset to be used for the ONNX model, default is the latest
394+
custom_ops: if a model contains ops not recognized by onnx runtime,
395+
you can tag these ops with a custom op domain so that the
396+
runtime can still open the model. Type is a dictionary `{op name: domain}`.
397+
custom_op_handlers: dictionary of custom ops handlers
398+
custom_rewriter: list of custom graph rewriters
399+
inputs_as_nchw: transpose inputs in list from nchw to nhwc
400+
extra_opset: list of extra opset's, for example the opset's used by custom ops
401+
shape_override: dict with inputs that override the shapes given by tensorflow
402+
target: list of workarounds applied to help certain platforms
403+
large_model: use the ONNX external tensor storage format
404+
output_path: save model to output_path
405+
406+
Returns:
407+
An ONNX model_proto and an external_tensor_storage dict.
408+
```
409+
380410
### Creating custom op mappings from python
381411

382412
For complex custom ops that require graph rewrites or input / attribute rewrites using the python interface to insert a custom op will be the easiest way to accomplish the task.
Binary file not shown.

tests/test_api.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,19 @@ def test_graphdef(self):
231231
self.assertTrue(output_names[0] == "pred")
232232
self.assertAllClose([2.1193342], oy[0], rtol=0.1, atol=0.1)
233233

234+
@check_tf_min_version("2.0")
235+
def test_tflite(self):
236+
output_path = os.path.join(self.test_data_directory, "model.onnx")
237+
238+
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
239+
model_proto, _ = tf2onnx.convert.from_tflite("tests/models/regression/tflite/test_api_model.tflite",
240+
input_names=['input'], output_names=['output'],
241+
output_path=output_path)
242+
output_names = [n.name for n in model_proto.graph.output]
243+
oy = self.run_onnxruntime(output_path, {"input": x_val}, output_names)
244+
self.assertTrue(output_names[0] == "output")
245+
exp_result = tf.add(x_val, x_val)
246+
self.assertAllClose(exp_result, oy[0], rtol=0.1, atol=0.1)
234247

235248
if __name__ == '__main__':
236249
unittest_main()

tf2onnx/convert.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,5 +636,61 @@ def from_graph_def(graph_def, name=None, input_names=None, output_names=None, op
636636
return model_proto, external_tensor_storage
637637

638638

639+
def from_tflite(tflite_path, input_names=None, output_names=None, opset=None, custom_ops=None, custom_op_handlers=None,
640+
custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, target=None,
641+
large_model=False, output_path=None):
642+
"""Returns a ONNX model_proto for a tflite model file.
643+
644+
Args:
645+
tflite_path: the tflite model file full path
646+
input_names: list of input names
647+
output_names: list of output names
648+
opset: the opset to be used for the ONNX model, default is the latest
649+
custom_ops: if a model contains ops not recognized by onnx runtime,
650+
you can tag these ops with a custom op domain so that the
651+
runtime can still open the model. Type is a dictionary `{op name: domain}`.
652+
custom_op_handlers: dictionary of custom ops handlers
653+
custom_rewriter: list of custom graph rewriters
654+
inputs_as_nchw: transpose inputs in list from nchw to nhwc
655+
extra_opset: list of extra opset's, for example the opset's used by custom ops
656+
shape_override: dict with inputs that override the shapes given by tensorflow
657+
target: list of workarounds applied to help certain platforms
658+
large_model: use the ONNX external tensor storage format
659+
output_path: save model to output_path
660+
661+
Returns:
662+
An ONNX model_proto and an external_tensor_storage dict.
663+
"""
664+
if not tflite_path:
665+
raise ValueError("tflite_path needs to be provided")
666+
if not input_names:
667+
input_names = []
668+
if not output_names:
669+
output_names = []
670+
671+
with tf.device("/cpu:0"):
672+
model_proto, external_tensor_storage = _convert_common(
673+
None,
674+
tflite_path=tflite_path,
675+
name=os.path.splitext(os.path.basename(tflite_path))[0],
676+
continue_on_error=True,
677+
target=target,
678+
opset=opset,
679+
custom_ops=custom_ops,
680+
custom_op_handlers=custom_op_handlers,
681+
custom_rewriter=custom_rewriter,
682+
extra_opset=extra_opset,
683+
shape_override=shape_override,
684+
input_names=input_names,
685+
output_names=output_names,
686+
inputs_as_nchw=inputs_as_nchw,
687+
large_model=large_model,
688+
tensors_to_rename=None,
689+
initialized_tables=None,
690+
output_path=output_path)
691+
692+
return model_proto, external_tensor_storage
693+
694+
639695
if __name__ == "__main__":
640696
main()

0 commit comments

Comments
 (0)