@@ -31,6 +31,12 @@ class Color:
31
31
BG_DEFAULT = '\033 [49m'
32
32
RESET = '\033 [0m'
33
33
34
+ ONNX_STANDARD_DOMAINS = [
35
+ 'ai.onnx' ,
36
+ 'ai.onnx.ml' ,
37
+ '' ,
38
+ ]
39
+
34
40
35
41
def extraction (
36
42
input_op_names : List [str ],
@@ -103,20 +109,39 @@ def extraction(
103
109
if not onnx_graph :
104
110
onnx_graph = onnx .load (input_onnx_file_path )
105
111
onnx_graph = onnx .shape_inference .infer_shapes (onnx_graph )
112
+
113
+ # Acquisition of Node with custom domain
114
+ custom_domain_check_onnx_nodes = []
115
+ custom_domain_check_onnx_nodes = \
116
+ custom_domain_check_onnx_nodes + \
117
+ [
118
+ node for node in onnx_graph .graph .node \
119
+ if node .domain not in ONNX_STANDARD_DOMAINS
120
+ ]
121
+
106
122
graph = gs .import_onnx (onnx_graph )
107
123
graph .cleanup ().toposort ()
108
124
125
+ # Check if Graph contains a custom domain (custom module)
126
+ contains_custom_domain = len (
127
+ [
128
+ domain \
129
+ for domain in graph .import_domains \
130
+ if domain .domain not in ONNX_STANDARD_DOMAINS
131
+ ]
132
+ ) > 0
133
+
109
134
# Extraction of input OP and output OP
110
135
graph_node_inputs = [
111
- graph_nodes \
112
- for graph_nodes in graph .nodes \
113
- for graph_nodes_input in graph_nodes .inputs \
136
+ graph_node \
137
+ for graph_node in graph .nodes \
138
+ for graph_nodes_input in graph_node .inputs \
114
139
if graph_nodes_input .name in input_op_names
115
140
]
116
141
graph_node_outputs = [
117
- graph_nodes \
118
- for graph_nodes in graph .nodes \
119
- for graph_nodes_output in graph_nodes .outputs \
142
+ graph_node \
143
+ for graph_node in graph .nodes \
144
+ for graph_nodes_output in graph_node .outputs \
120
145
if graph_nodes_output .name in output_op_names
121
146
]
122
147
@@ -128,8 +153,10 @@ def extraction(
128
153
input_tmp = []
129
154
for graph_node in graph_node_inputs :
130
155
for graph_node_input in graph_node .inputs :
131
- # if graph_node_input.shape and graph_node_input.name not in [i.name for i in input_tmp]:
132
- if graph_node_input .shape and graph_node_input not in [i for i in input_tmp ]:
156
+ if graph_node_input .shape \
157
+ and graph_node_input not in [i for i in input_tmp ] \
158
+ and hasattr (graph_node_input , 'name' ) \
159
+ and graph_node_input .name in [i for i in input_op_names ]:
133
160
input_tmp .append (graph_node_input )
134
161
graph .inputs = input_tmp
135
162
@@ -155,10 +182,21 @@ def extraction(
155
182
'Be sure to open the .onnx file to verify the certainty of the geometry.'
156
183
)
157
184
185
+ ## 5. Restore a node's custom domain
186
+ if contains_custom_domain :
187
+ extracted_graph_nodes = extracted_graph .graph .node
188
+ for extracted_graph_node in extracted_graph_nodes :
189
+ for custom_domain_check_onnx_node in custom_domain_check_onnx_nodes :
190
+ if extracted_graph_node .name == custom_domain_check_onnx_node .name :
191
+ extracted_graph_node .domain = custom_domain_check_onnx_node .domain
192
+
158
193
# Save
159
194
if output_onnx_file_path :
160
195
onnx .save (extracted_graph , output_onnx_file_path )
161
196
197
+ if not non_verbose :
198
+ print (f'{ Color .GREEN } INFO:{ Color .RESET } Finish!' )
199
+
162
200
return extracted_graph
163
201
164
202
0 commit comments