3
3
import sys
4
4
from argparse import ArgumentParser
5
5
import onnx
6
+ import onnx_graphsurgeon as gs
6
7
from typing import Optional , List
7
8
8
9
class Color :
@@ -37,19 +38,18 @@ def extraction(
37
38
input_onnx_file_path : Optional [str ] = '' ,
38
39
onnx_graph : Optional [onnx .ModelProto ] = None ,
39
40
output_onnx_file_path : Optional [str ] = '' ,
41
+ non_verbose : Optional [bool ] = False ,
40
42
) -> onnx .ModelProto :
41
43
42
44
"""
43
45
Parameters
44
46
----------
45
47
input_op_names: List[str]
46
48
List of OP names to specify for the input layer of the model.\n \
47
- Specify the name of the OP, separated by commas.\n \
48
49
e.g. ['aaa','bbb','ccc']
49
50
50
51
output_op_names: List[str]
51
52
List of OP names to specify for the output layer of the model.\n \
52
- Specify the name of the OP, separated by commas.\n \
53
53
e.g. ['ddd','eee','fff']
54
54
55
55
input_onnx_file_path: Optional[str]
@@ -67,6 +67,10 @@ def extraction(
67
67
If not specified, .onnx is not output.\n \
68
68
Default: ''
69
69
70
+ non_verbose: Optional[bool]
71
+ Do not show all information logs. Only error logs are displayed.\n \
72
+ Default: False
73
+
70
74
Returns
71
75
-------
72
76
extracted_graph: onnx.ModelProto
@@ -80,19 +84,55 @@ def extraction(
80
84
)
81
85
sys .exit (1 )
82
86
87
+ if not input_op_names :
88
+ print (
89
+ f'{ Color .RED } ERROR:{ Color .RESET } ' +
90
+ f'One or more input_op_names must be specified.'
91
+ )
92
+ sys .exit (1 )
93
+
94
+ if not output_op_names :
95
+ print (
96
+ f'{ Color .RED } ERROR:{ Color .RESET } ' +
97
+ f'One or more output_op_names must be specified.'
98
+ )
99
+ sys .exit (1 )
100
+
83
101
# Load
84
102
graph = None
85
103
if not onnx_graph :
86
- graph = onnx .load (input_onnx_file_path )
87
- else :
88
- graph = onnx_graph
89
-
90
- # Extract
91
- extractor = onnx .utils .Extractor (graph )
92
- extracted_graph = extractor .extract_model (
93
- input_op_names ,
94
- output_op_names ,
95
- )
104
+ onnx_graph = onnx .load (input_onnx_file_path )
105
+ onnx_graph = onnx .shape_inference .infer_shapes (onnx_graph )
106
+ graph = gs .import_onnx (onnx_graph )
107
+ graph .cleanup ().toposort ()
108
+
109
+ # Extraction of input OP and output OP
110
+ graph_node_inputs = [graph_nodes for graph_nodes in graph .nodes for graph_nodes_input in graph_nodes .inputs if graph_nodes_input .name in input_op_names ]
111
+ graph_node_outputs = [graph_nodes for graph_nodes in graph .nodes for graph_nodes_output in graph_nodes .outputs if graph_nodes_output .name in output_op_names ]
112
+
113
+ # Init graph INPUT/OUTPUT
114
+ graph .inputs .clear ()
115
+ graph .outputs .clear ()
116
+
117
+ # Update graph INPUT/OUTPUT
118
+ graph .inputs = [graph_node_input for graph_node in graph_node_inputs for graph_node_input in graph_node .inputs if graph_node_input .shape ]
119
+ graph .outputs = [graph_node_output for graph_node in graph_node_outputs for graph_node_output in graph_node .outputs ]
120
+
121
+ # Cleanup
122
+ graph .cleanup ().toposort ()
123
+
124
+ # Shape Estimation
125
+ extracted_graph = None
126
+ try :
127
+ extracted_graph = onnx .shape_inference .infer_shapes (gs .export_onnx (graph ))
128
+ except Exception as e :
129
+ extracted_graph = gs .export_onnx (graph )
130
+ if not non_verbose :
131
+ print (
132
+ f'{ Color .YELLOW } WARNING:{ Color .RESET } ' +
133
+ 'The input shape of the next OP does not match the output shape. ' +
134
+ 'Be sure to open the .onnx file to verify the certainty of the geometry.'
135
+ )
96
136
97
137
# Save
98
138
if output_onnx_file_path :
@@ -112,38 +152,51 @@ def main():
112
152
parser .add_argument (
113
153
'--input_op_names' ,
114
154
type = str ,
155
+ nargs = '+' ,
115
156
required = True ,
116
157
help = "\
117
158
List of OP names to specify for the input layer of the model. \
118
- Specify the name of the OP, separated by commas. \
119
- e.g. --input_op_names aaa,bbb,ccc"
159
+ e.g. --input_op_names aaa bbb ccc"
120
160
)
121
161
parser .add_argument (
122
162
'--output_op_names' ,
123
163
type = str ,
164
+ nargs = '+' ,
124
165
required = True ,
125
166
help = "\
126
167
List of OP names to specify for the output layer of the model. \
127
- Specify the name of the OP, separated by commas. \
128
- e.g. --output_op_names ddd,eee,fff"
168
+ e.g. --output_op_names ddd eee fff"
129
169
)
130
170
parser .add_argument (
131
171
'--output_onnx_file_path' ,
132
172
type = str ,
133
173
default = 'extracted.onnx' ,
134
174
help = 'Output onnx file path. If not specified, extracted.onnx is output.'
135
175
)
176
+ parser .add_argument (
177
+ '--non_verbose' ,
178
+ action = 'store_true' ,
179
+ help = 'Do not show all information logs. Only error logs are displayed.'
180
+ )
136
181
args = parser .parse_args ()
137
182
138
- input_op_names = args .input_op_names .strip (' ,' ).replace (' ' ,'' ).split (',' )
139
- output_op_names = args .output_op_names .strip (' ,' ).replace (' ' ,'' ).split (',' )
183
+ input_onnx_file_path = args .input_onnx_file_path
184
+ input_op_names = args .input_op_names
185
+ output_op_names = args .output_op_names
186
+ output_onnx_file_path = args .output_onnx_file_path
187
+ non_verbose = args .non_verbose
188
+
189
+ # Load
190
+ onnx_graph = onnx .load (input_onnx_file_path )
140
191
141
192
# Model extraction
142
193
extracted_graph = extraction (
143
- input_onnx_file_path = args . input_onnx_file_path ,
194
+ input_onnx_file_path = None ,
144
195
input_op_names = input_op_names ,
145
196
output_op_names = output_op_names ,
146
- output_onnx_file_path = args .output_onnx_file_path ,
197
+ onnx_graph = onnx_graph ,
198
+ output_onnx_file_path = output_onnx_file_path ,
199
+ non_verbose = non_verbose ,
147
200
)
148
201
149
202
0 commit comments