From d90b6178446df7082c83cfd7cd0abd073e131a1e Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Wed, 8 Jun 2022 20:29:56 +0800 Subject: [PATCH 1/2] add shape inference for onnx-optimize Signed-off-by: Deyu Huang --- tools/onnx-optimize.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tools/onnx-optimize.py b/tools/onnx-optimize.py index cd263a8e3..e649b780b 100644 --- a/tools/onnx-optimize.py +++ b/tools/onnx-optimize.py @@ -13,7 +13,7 @@ import logging import onnx -from onnx import helper +from onnx import helper, shape_inference from tf2onnx.graph import GraphUtil from tf2onnx import logging, optimizer, constants @@ -46,6 +46,12 @@ def load_graph(fname, target): return g, model_proto +def model_shape_inference(onnx_model_proto): + inferred_model = shape_inference.infer_shapes(onnx_model_proto) + onnx.checker.check_model(inferred_model) + return inferred_model + + def main(): args = get_args() @@ -64,10 +70,12 @@ def main(): model_proto = helper.make_model(onnx_graph, **kwargs) + model_proto_inferred = model_shape_inference(model_proto) + # write onnx graph if args.output: with open(args.output, "wb") as f: - f.write(model_proto.SerializeToString()) + f.write(model_proto_inferred.SerializeToString()) if __name__ == "__main__": From a46e68a2af770738a591c5b4d7550c454b2f0cd0 Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Thu, 9 Jun 2022 14:34:10 +0800 Subject: [PATCH 2/2] pylint whitespace Signed-off-by: Deyu Huang --- tools/onnx-optimize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/onnx-optimize.py b/tools/onnx-optimize.py index e649b780b..ed7436be8 100644 --- a/tools/onnx-optimize.py +++ b/tools/onnx-optimize.py @@ -71,7 +71,7 @@ def main(): model_proto = helper.make_model(onnx_graph, **kwargs) model_proto_inferred = model_shape_inference(model_proto) - + # write onnx graph if args.output: with open(args.output, "wb") as f: