Skip to content

Commit 7d7be5b

Browse files
committed
Support for onnx.ModelProto input
1 parent e04cdd0 commit 7d7be5b

File tree

3 files changed

+53
-19
lines changed

3 files changed

+53
-19
lines changed

README.md

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ A very simple tool for situations where optimization with onnx-simplifier would
55

66
# Key concept
77
- [x] If INPUT OP name and OUTPUT OP name are specified, the onnx graph within the range of the specified OP name is extracted and .onnx is generated.
8-
- [ ] Change backend to onnx-graphsurgeon so that onnx.ModelProto can be specified as input.
8+
- [x] Change backend to `onnx.utils.Extractor.extract_model` so that onnx.ModelProto can be specified as input.
99

1010
## 1. Setup
1111
### 1-1. HostPC
@@ -75,7 +75,8 @@ extraction(
7575
input_onnx_file_path: str,
7676
input_op_names: List[str],
7777
output_op_names: List[str],
78-
output_onnx_file_path: Union[str, NoneType] = ''
78+
output_onnx_file_path: Union[str, NoneType] = '',
79+
onnx_graph: Union[onnx.onnx_ml_pb2.ModelProto, NoneType] = None
7980
) -> onnx.onnx_ml_pb2.ModelProto
8081

8182
Parameters
@@ -98,6 +99,11 @@ extraction(
9899
If not specified, .onnx is not output.
99100
Default: ''
100101

102+
onnx_graph: Optional[onnx.ModelProto]
103+
onnx.ModelProto.
104+
Either input_onnx_file_path or onnx_graph must be specified.
105+
onnx_graph If specified, ignore input_onnx_file_path and process onnx_graph.
106+
101107
Returns
102108
-------
103109
extracted_graph: onnx.ModelProto
@@ -114,6 +120,7 @@ $ sne4onnx \
114120
```
115121

116122
## 5. In-script Execution
123+
### 5-1. Use ONNX files
117124
```python
118125
from sne4onnx import extraction
119126

@@ -124,6 +131,17 @@ extracted_graph = extraction(
124131
output_onnx_file_path='output.onnx',
125132
)
126133
```
134+
### 5-2. Use onnx.ModelProto
135+
```python
136+
from sne4onnx import extraction
137+
138+
extracted_graph = extraction(
139+
input_op_names=['aaa', 'bbb', 'ccc'],
140+
output_op_names=['ddd', 'eee', 'fff'],
141+
output_onnx_file_path='output.onnx',
142+
onnx_graph=graph,
143+
)
144+
```
127145

128146
## 6. Samples
129147
### 6-1. Pre-extraction
@@ -147,6 +165,8 @@ $ sne4onnx \
147165

148166
## 7. Reference
149167
1. https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md
150-
2. https://github.com/PINTO0309/snd4onnx
151-
3. https://github.com/PINTO0309/scs4onnx
152-
4. https://github.com/PINTO0309/snc4onnx
168+
2. https://docs.nvidia.com/deeplearning/tensorrt/onnx-graphsurgeon/docs/index.html
169+
3. https://github.com/NVIDIA/TensorRT/tree/main/tools/onnx-graphsurgeon
170+
4. https://github.com/PINTO0309/snd4onnx
171+
5. https://github.com/PINTO0309/scs4onnx
172+
6. https://github.com/PINTO0309/snc4onnx

sne4onnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from sne4onnx.onnx_network_extraction import extraction, main
22

3-
__version__ = '1.0.3'
3+
__version__ = '1.0.4'

sne4onnx/onnx_network_extraction.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#! /usr/bin/env python
22

3-
import os
3+
import sys
44
from argparse import ArgumentParser
55
import onnx
66
from typing import Optional, List
@@ -36,6 +36,7 @@ def extraction(
3636
input_op_names: List[str],
3737
output_op_names: List[str],
3838
output_onnx_file_path: Optional[str] = '',
39+
onnx_graph: Optional[onnx.ModelProto] = None,
3940
) -> onnx.ModelProto:
4041

4142
"""
@@ -59,28 +60,41 @@ def extraction(
5960
If not specified, .onnx is not output.\n\
6061
Default: ''
6162
63+
onnx_graph: Optional[onnx.ModelProto]
64+
onnx.ModelProto.\n\
65+
Either input_onnx_file_path or onnx_graph must be specified.\n\
66+
onnx_graph If specified, ignore input_onnx_file_path and process onnx_graph.
67+
6268
Returns
6369
-------
6470
extracted_graph: onnx.ModelProto
6571
Extracted onnx ModelProto
6672
"""
6773

68-
tmp_onnx_file = ''
69-
if not output_onnx_file_path:
70-
tmp_onnx_file = 'extracted.onnx'
74+
if not input_onnx_file_path and not onnx_graph:
75+
print(
76+
f'{Color.RED}ERROR:{Color.RESET} '+
77+
f'One of input_onnx_file_path or onnx_graph must be specified.'
78+
)
79+
sys.exit(1)
80+
81+
# Load
82+
graph = None
83+
if not onnx_graph:
84+
graph = onnx.load(input_onnx_file_path)
7185
else:
72-
tmp_onnx_file = output_onnx_file_path
86+
graph = onnx_graph
7387

74-
onnx.utils.extract_model(
75-
input_onnx_file_path,
76-
tmp_onnx_file,
88+
# Extract
89+
extractor = onnx.utils.Extractor(graph)
90+
extracted_graph = extractor.extract_model(
7791
input_op_names,
78-
output_op_names
92+
output_op_names,
7993
)
8094

81-
extracted_graph = onnx.load(tmp_onnx_file)
82-
if not output_onnx_file_path:
83-
os.remove(tmp_onnx_file)
95+
# Save
96+
if output_onnx_file_path:
97+
onnx.save(extracted_graph, output_onnx_file_path)
8498

8599
return extracted_graph
86100

@@ -132,4 +146,4 @@ def main():
132146

133147

134148
if __name__ == '__main__':
135-
main()
149+
main()

0 commit comments

Comments
 (0)