Skip to content

Commit 2c7056f

Browse files
authored
Merge pull request #4 from PINTO0309/fix_irversion
Fix to preserve `domain` and `ir_version`
2 parents fdb1cae + 051ab50 commit 2c7056f

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

sne4onnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from sne4onnx.onnx_network_extraction import extraction, main
22

3-
__version__ = '1.0.11'
3+
__version__ = '1.0.12'

sne4onnx/onnx_network_extraction.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ def extraction(
119119
if node.domain not in ONNX_STANDARD_DOMAINS
120120
]
121121

122+
# domain, ir_version
123+
domain: str = onnx_graph.domain
124+
ir_version: int = onnx_graph.ir_version
125+
122126
graph = gs.import_onnx(onnx_graph)
123127
graph.cleanup().toposort()
124128

@@ -172,9 +176,9 @@ def extraction(
172176
# Shape Estimation
173177
extracted_graph = None
174178
try:
175-
extracted_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
179+
extracted_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version}))
176180
except Exception as e:
177-
extracted_graph = gs.export_onnx(graph)
181+
extracted_graph = gs.export_onnx(graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version})
178182
if not non_verbose:
179183
print(
180184
f'{Color.YELLOW}WARNING:{Color.RESET} '+

0 commit comments

Comments
 (0)