Skip to content

Commit fdb1cae

Browse files
authored
Merge pull request #3 from PINTO0309/support_custom_domain
Support for models with custom domains and elimination of critical bugs
2 parents 0e63f63 + f1d3cbd commit fdb1cae

File tree

2 files changed

+47
-9
lines changed

2 files changed

+47
-9
lines changed

sne4onnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from sne4onnx.onnx_network_extraction import extraction, main
22

3-
__version__ = '1.0.10'
3+
__version__ = '1.0.11'

sne4onnx/onnx_network_extraction.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ class Color:
3131
BG_DEFAULT = '\033[49m'
3232
RESET = '\033[0m'
3333

34+
ONNX_STANDARD_DOMAINS = [
35+
'ai.onnx',
36+
'ai.onnx.ml',
37+
'',
38+
]
39+
3440

3541
def extraction(
3642
input_op_names: List[str],
@@ -103,20 +109,39 @@ def extraction(
103109
if not onnx_graph:
104110
onnx_graph = onnx.load(input_onnx_file_path)
105111
onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
112+
113+
# Acquisition of Node with custom domain
114+
custom_domain_check_onnx_nodes = []
115+
custom_domain_check_onnx_nodes = \
116+
custom_domain_check_onnx_nodes + \
117+
[
118+
node for node in onnx_graph.graph.node \
119+
if node.domain not in ONNX_STANDARD_DOMAINS
120+
]
121+
106122
graph = gs.import_onnx(onnx_graph)
107123
graph.cleanup().toposort()
108124

125+
# Check if Graph contains a custom domain (custom module)
126+
contains_custom_domain = len(
127+
[
128+
domain \
129+
for domain in graph.import_domains \
130+
if domain.domain not in ONNX_STANDARD_DOMAINS
131+
]
132+
) > 0
133+
109134
# Extraction of input OP and output OP
110135
graph_node_inputs = [
111-
graph_nodes \
112-
for graph_nodes in graph.nodes \
113-
for graph_nodes_input in graph_nodes.inputs \
136+
graph_node \
137+
for graph_node in graph.nodes \
138+
for graph_nodes_input in graph_node.inputs \
114139
if graph_nodes_input.name in input_op_names
115140
]
116141
graph_node_outputs = [
117-
graph_nodes \
118-
for graph_nodes in graph.nodes \
119-
for graph_nodes_output in graph_nodes.outputs \
142+
graph_node \
143+
for graph_node in graph.nodes \
144+
for graph_nodes_output in graph_node.outputs \
120145
if graph_nodes_output.name in output_op_names
121146
]
122147

@@ -128,8 +153,10 @@ def extraction(
128153
input_tmp = []
129154
for graph_node in graph_node_inputs:
130155
for graph_node_input in graph_node.inputs:
131-
# if graph_node_input.shape and graph_node_input.name not in [i.name for i in input_tmp]:
132-
if graph_node_input.shape and graph_node_input not in [i for i in input_tmp]:
156+
if graph_node_input.shape \
157+
and graph_node_input not in [i for i in input_tmp] \
158+
and hasattr(graph_node_input, 'name') \
159+
and graph_node_input.name in [i for i in input_op_names]:
133160
input_tmp.append(graph_node_input)
134161
graph.inputs = input_tmp
135162

@@ -155,10 +182,21 @@ def extraction(
155182
'Be sure to open the .onnx file to verify the certainty of the geometry.'
156183
)
157184

185+
## 5. Restore a node's custom domain
186+
if contains_custom_domain:
187+
extracted_graph_nodes = extracted_graph.graph.node
188+
for extracted_graph_node in extracted_graph_nodes:
189+
for custom_domain_check_onnx_node in custom_domain_check_onnx_nodes:
190+
if extracted_graph_node.name == custom_domain_check_onnx_node.name:
191+
extracted_graph_node.domain = custom_domain_check_onnx_node.domain
192+
158193
# Save
159194
if output_onnx_file_path:
160195
onnx.save(extracted_graph, output_onnx_file_path)
161196

197+
if not non_verbose:
198+
print(f'{Color.GREEN}INFO:{Color.RESET} Finish!')
199+
162200
return extracted_graph
163201

164202

0 commit comments

Comments
 (0)