Skip to content

Commit d0d285e

Browse files
authored
Merge pull request #2 from PINTO0309/performance
Significantly faster processing
2 parents 9499494 + a677e1f commit d0d285e

File tree

3 files changed

+90
-32
lines changed

3 files changed

+90
-32
lines changed

README.md

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ usage:
3333
--input_op_names INPUT_OP_NAMES
3434
--output_op_names OUTPUT_OP_NAMES
3535
[--output_onnx_file_path OUTPUT_ONNX_FILE_PATH]
36+
[--non_verbose]
3637

3738
optional arguments:
3839
-h, --help
@@ -43,16 +44,17 @@ optional arguments:
4344

4445
--input_op_names INPUT_OP_NAMES
4546
List of OP names to specify for the input layer of the model.
46-
Specify the name of the OP, separated by commas.
47-
e.g. --input_op_names aaa,bbb,ccc
47+
e.g. --input_op_names aaa bbb ccc
4848

4949
--output_op_names OUTPUT_OP_NAMES
5050
List of OP names to specify for the output layer of the model.
51-
Specify the name of the OP, separated by commas.
52-
e.g. --output_op_names ddd,eee,fff
51+
e.g. --output_op_names ddd eee fff
5352

5453
--output_onnx_file_path OUTPUT_ONNX_FILE_PATH
5554
Output onnx file path. If not specified, extracted.onnx is output.
55+
56+
--non_verbose
57+
Do not show all information logs. Only error logs are displayed.
5658
```
5759

5860
## 3. In-script Usage
@@ -68,19 +70,18 @@ extraction(
6870
output_op_names: List[str],
6971
input_onnx_file_path: Union[str, NoneType] = '',
7072
onnx_graph: Union[onnx.onnx_ml_pb2.ModelProto, NoneType] = None,
71-
output_onnx_file_path: Union[str, NoneType] = ''
73+
output_onnx_file_path: Union[str, NoneType] = '',
74+
non_verbose: Optional[bool] = False
7275
) -> onnx.onnx_ml_pb2.ModelProto
7376

7477
Parameters
7578
----------
7679
input_op_names: List[str]
7780
List of OP names to specify for the input layer of the model.
78-
Specify the name of the OP, separated by commas.
7981
e.g. ['aaa','bbb','ccc']
8082

8183
output_op_names: List[str]
8284
List of OP names to specify for the output layer of the model.
83-
Specify the name of the OP, separated by commas.
8485
e.g. ['ddd','eee','fff']
8586

8687
input_onnx_file_path: Optional[str]
@@ -98,6 +99,10 @@ extraction(
9899
If not specified, .onnx is not output.
99100
Default: ''
100101

102+
non_verbose: Optional[bool]
103+
Do not show all information logs. Only error logs are displayed.
104+
Default: False
105+
101106
Returns
102107
-------
103108
extracted_graph: onnx.ModelProto
@@ -108,8 +113,8 @@ extraction(
108113
```bash
109114
$ sne4onnx \
110115
--input_onnx_file_path input.onnx \
111-
--input_op_names aaa,bbb,ccc \
112-
--output_op_names ddd,eee,fff \
116+
--input_op_names aaa bbb ccc \
117+
--output_op_names ddd eee fff \
113118
--output_onnx_file_path output.onnx
114119
```
115120

@@ -147,8 +152,8 @@ extracted_graph = extraction(
147152
```bash
148153
$ sne4onnx \
149154
--input_onnx_file_path hitnet_sf_finalpass_720x1280.onnx \
150-
--input_op_names 0,1 \
151-
--output_op_names 497,785 \
155+
--input_op_names 0 1 \
156+
--output_op_names 497 785 \
152157
--output_onnx_file_path hitnet_sf_finalpass_720x960_head.onnx
153158
```
154159

sne4onnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from sne4onnx.onnx_network_extraction import extraction, main
22

3-
__version__ = '1.0.5'
3+
__version__ = '1.0.6'

sne4onnx/onnx_network_extraction.py

Lines changed: 73 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
from argparse import ArgumentParser
55
import onnx
6+
import onnx_graphsurgeon as gs
67
from typing import Optional, List
78

89
class Color:
@@ -37,19 +38,18 @@ def extraction(
3738
input_onnx_file_path: Optional[str] = '',
3839
onnx_graph: Optional[onnx.ModelProto] = None,
3940
output_onnx_file_path: Optional[str] = '',
41+
non_verbose: Optional[bool] = False,
4042
) -> onnx.ModelProto:
4143

4244
"""
4345
Parameters
4446
----------
4547
input_op_names: List[str]
4648
List of OP names to specify for the input layer of the model.\n\
47-
Specify the name of the OP, separated by commas.\n\
4849
e.g. ['aaa','bbb','ccc']
4950
5051
output_op_names: List[str]
5152
List of OP names to specify for the output layer of the model.\n\
52-
Specify the name of the OP, separated by commas.\n\
5353
e.g. ['ddd','eee','fff']
5454
5555
input_onnx_file_path: Optional[str]
@@ -67,6 +67,10 @@ def extraction(
6767
If not specified, .onnx is not output.\n\
6868
Default: ''
6969
70+
non_verbose: Optional[bool]
71+
Do not show all information logs. Only error logs are displayed.\n\
72+
Default: False
73+
7074
Returns
7175
-------
7276
extracted_graph: onnx.ModelProto
@@ -80,19 +84,55 @@ def extraction(
8084
)
8185
sys.exit(1)
8286

87+
if not input_op_names:
88+
print(
89+
f'{Color.RED}ERROR:{Color.RESET} '+
90+
f'One or more input_op_names must be specified.'
91+
)
92+
sys.exit(1)
93+
94+
if not output_op_names:
95+
print(
96+
f'{Color.RED}ERROR:{Color.RESET} '+
97+
f'One or more output_op_names must be specified.'
98+
)
99+
sys.exit(1)
100+
83101
# Load
84102
graph = None
85103
if not onnx_graph:
86-
graph = onnx.load(input_onnx_file_path)
87-
else:
88-
graph = onnx_graph
89-
90-
# Extract
91-
extractor = onnx.utils.Extractor(graph)
92-
extracted_graph = extractor.extract_model(
93-
input_op_names,
94-
output_op_names,
95-
)
104+
onnx_graph = onnx.load(input_onnx_file_path)
105+
onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
106+
graph = gs.import_onnx(onnx_graph)
107+
graph.cleanup().toposort()
108+
109+
# Extraction of input OP and output OP
110+
graph_node_inputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_input in graph_nodes.inputs if graph_nodes_input.name in input_op_names]
111+
graph_node_outputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_output in graph_nodes.outputs if graph_nodes_output.name in output_op_names]
112+
113+
# Init graph INPUT/OUTPUT
114+
graph.inputs.clear()
115+
graph.outputs.clear()
116+
117+
# Update graph INPUT/OUTPUT
118+
graph.inputs = [graph_node_input for graph_node in graph_node_inputs for graph_node_input in graph_node.inputs if graph_node_input.shape]
119+
graph.outputs = [graph_node_output for graph_node in graph_node_outputs for graph_node_output in graph_node.outputs]
120+
121+
# Cleanup
122+
graph.cleanup().toposort()
123+
124+
# Shape Estimation
125+
extracted_graph = None
126+
try:
127+
extracted_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
128+
except Exception as e:
129+
extracted_graph = gs.export_onnx(graph)
130+
if not non_verbose:
131+
print(
132+
f'{Color.YELLOW}WARNING:{Color.RESET} '+
133+
'The input shape of the next OP does not match the output shape. '+
134+
'Be sure to open the .onnx file to verify the certainty of the geometry.'
135+
)
96136

97137
# Save
98138
if output_onnx_file_path:
@@ -112,38 +152,51 @@ def main():
112152
parser.add_argument(
113153
'--input_op_names',
114154
type=str,
155+
nargs='+',
115156
required=True,
116157
help="\
117158
List of OP names to specify for the input layer of the model. \
118-
Specify the name of the OP, separated by commas. \
119-
e.g. --input_op_names aaa,bbb,ccc"
159+
e.g. --input_op_names aaa bbb ccc"
120160
)
121161
parser.add_argument(
122162
'--output_op_names',
123163
type=str,
164+
nargs='+',
124165
required=True,
125166
help="\
126167
List of OP names to specify for the output layer of the model. \
127-
Specify the name of the OP, separated by commas. \
128-
e.g. --output_op_names ddd,eee,fff"
168+
e.g. --output_op_names ddd eee fff"
129169
)
130170
parser.add_argument(
131171
'--output_onnx_file_path',
132172
type=str,
133173
default='extracted.onnx',
134174
help='Output onnx file path. If not specified, extracted.onnx is output.'
135175
)
176+
parser.add_argument(
177+
'--non_verbose',
178+
action='store_true',
179+
help='Do not show all information logs. Only error logs are displayed.'
180+
)
136181
args = parser.parse_args()
137182

138-
input_op_names = args.input_op_names.strip(' ,').replace(' ','').split(',')
139-
output_op_names = args.output_op_names.strip(' ,').replace(' ','').split(',')
183+
input_onnx_file_path = args.input_onnx_file_path
184+
input_op_names = args.input_op_names
185+
output_op_names = args.output_op_names
186+
output_onnx_file_path = args.output_onnx_file_path
187+
non_verbose = args.non_verbose
188+
189+
# Load
190+
onnx_graph = onnx.load(input_onnx_file_path)
140191

141192
# Model extraction
142193
extracted_graph = extraction(
143-
input_onnx_file_path=args.input_onnx_file_path,
194+
input_onnx_file_path=None,
144195
input_op_names=input_op_names,
145196
output_op_names=output_op_names,
146-
output_onnx_file_path=args.output_onnx_file_path,
197+
onnx_graph=onnx_graph,
198+
output_onnx_file_path=output_onnx_file_path,
199+
non_verbose=non_verbose,
147200
)
148201

149202

0 commit comments

Comments
 (0)