Skip to content

Commit 9a83270

Browse files
committed
return onnx.ModelProto
1 parent 3ff789c commit 9a83270

File tree

3 files changed

+66
-34
lines changed

3 files changed

+66
-34
lines changed

README.md

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,27 +35,31 @@ $ cd /workdir
3535
$ sne4onnx -h
3636

3737
usage:
38-
sne4onnx [-h] \
39-
--input_onnx_file_path INPUT_ONNX_FILE_PATH \
40-
--output_onnx_file_path OUTPUT_ONNX_FILE_PATH \
41-
--input_op_names INPUT_OP_NAMES \
38+
sne4onnx [-h]
39+
--input_onnx_file_path INPUT_ONNX_FILE_PATH
40+
--input_op_names INPUT_OP_NAMES
4241
--output_op_names OUTPUT_OP_NAMES
42+
[--output_onnx_file_path OUTPUT_ONNX_FILE_PATH]
4343

4444
optional arguments:
4545
-h, --help
4646
show this help message and exit
47+
4748
--input_onnx_file_path INPUT_ONNX_FILE_PATH
4849
Input onnx file path.
49-
--output_onnx_file_path OUTPUT_ONNX_FILE_PATH
50-
Output onnx file path.
50+
5151
--input_op_names INPUT_OP_NAMES
5252
List of OP names to specify for the input layer of the model.
5353
Specify the name of the OP, separated by commas.
5454
e.g. --input_op_names aaa,bbb,ccc
55+
5556
--output_op_names OUTPUT_OP_NAMES
5657
List of OP names to specify for the output layer of the model.
5758
Specify the name of the OP, separated by commas.
5859
e.g. --output_op_names ddd,eee,fff
60+
61+
--output_onnx_file_path OUTPUT_ONNX_FILE_PATH
62+
Output onnx file path. If not specified, extracted.onnx is output.
5963
```
6064

6165
## 3. In-script Usage
@@ -68,19 +72,16 @@ Help on function extraction in module sne4onnx.onnx_network_extraction:
6872

6973
extraction(
7074
input_onnx_file_path: str,
71-
output_onnx_file_path: str,
7275
input_op_names: List[str],
73-
output_op_names: List[str]
74-
)
76+
output_op_names: List[str],
77+
output_onnx_file_path: Union[str, NoneType] = ''
78+
) -> onnx.onnx_ml_pb2.ModelProto
7579

7680
Parameters
7781
----------
7882
input_onnx_file_path: str
7983
Input onnx file path.
8084

81-
output_onnx_file_path: str
82-
Output onnx file path.
83-
8485
input_op_names: List[str]
8586
List of OP names to specify for the input layer of the model.
8687
Specify the name of the OP, separated by commas.
@@ -90,26 +91,36 @@ extraction(
9091
List of OP names to specify for the output layer of the model.
9192
Specify the name of the OP, separated by commas.
9293
e.g. ['ddd','eee','fff']
94+
95+
output_onnx_file_path: Optional[str]
96+
Output onnx file path.
97+
If not specified, .onnx is not output.
98+
Default: ''
99+
100+
Returns
101+
-------
102+
extracted_graph: onnx.ModelProto
103+
Extracted onnx ModelProto
93104
```
94105

95106
## 4. CLI Execution
96107
```bash
97108
$ sne4onnx \
98109
--input_onnx_file_path input.onnx \
99-
--output_onnx_file_path output.onnx \
100110
--input_op_names aaa,bbb,ccc \
101-
--output_op_names ddd,eee,fff
111+
--output_op_names ddd,eee,fff \
112+
--output_onnx_file_path output.onnx
102113
```
103114

104115
## 5. In-script Execution
105116
```python
106117
from sne4onnx import extraction
107118

108-
extraction(
119+
extracted_graph = extraction(
109120
input_onnx_file_path='input.onnx',
110-
output_onnx_file_path='output.onnx',
111121
input_op_names=['aaa', 'bbb', 'ccc'],
112122
output_op_names=['ddd', 'eee', 'fff'],
123+
output_onnx_file_path='output.onnx',
113124
)
114125
```
115126

@@ -123,9 +134,9 @@ extraction(
123134
```bash
124135
$ sne4onnx \
125136
--input_onnx_file_path hitnet_sf_finalpass_720x1280.onnx \
126-
--output_onnx_file_path hitnet_sf_finalpass_720x960_head.onnx \
127137
--input_op_names 0,1 \
128-
--output_op_names 497,785
138+
--output_op_names 497,785 \
139+
--output_onnx_file_path hitnet_sf_finalpass_720x960_head.onnx
129140
```
130141

131142
### 6-3. Extracted

sne4onnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from sne4onnx.onnx_network_extraction import extraction, main
22

3-
__version__ = '1.0.2'
3+
__version__ = '1.0.3'

sne4onnx/onnx_network_extraction.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#! /usr/bin/env python
22

3+
import os
34
from argparse import ArgumentParser
45
import onnx
5-
from typing import List
6+
from typing import Optional, List
67

78
class Color:
89
BLACK = '\033[30m'
@@ -32,20 +33,17 @@ class Color:
3233

3334
def extraction(
3435
input_onnx_file_path: str,
35-
output_onnx_file_path: str,
3636
input_op_names: List[str],
3737
output_op_names: List[str],
38-
):
38+
output_onnx_file_path: Optional[str] = '',
39+
) -> onnx.ModelProto:
3940

4041
"""
4142
Parameters
4243
----------
4344
input_onnx_file_path: str
4445
Input onnx file path.
4546
46-
output_onnx_file_path: str
47-
Output onnx file path.
48-
4947
input_op_names: List[str]
5048
List of OP names to specify for the input layer of the model.\n\
5149
Specify the name of the OP, separated by commas.\n\
@@ -55,15 +53,37 @@ def extraction(
5553
List of OP names to specify for the output layer of the model.\n\
5654
Specify the name of the OP, separated by commas.\n\
5755
e.g. ['ddd','eee','fff']
56+
57+
output_onnx_file_path: Optional[str]
58+
Output onnx file path.\n\
59+
If not specified, .onnx is not output.\n\
60+
Default: ''
61+
62+
Returns
63+
-------
64+
extracted_graph: onnx.ModelProto
65+
Extracted onnx ModelProto
5866
"""
5967

68+
tmp_onnx_file = ''
69+
if not output_onnx_file_path:
70+
tmp_onnx_file = 'extracted.onnx'
71+
else:
72+
tmp_onnx_file = output_onnx_file_path
73+
6074
onnx.utils.extract_model(
6175
input_onnx_file_path,
62-
output_onnx_file_path,
76+
tmp_onnx_file,
6377
input_op_names,
6478
output_op_names
6579
)
6680

81+
extracted_graph = onnx.load(tmp_onnx_file)
82+
if not output_onnx_file_path:
83+
os.remove(tmp_onnx_file)
84+
85+
return extracted_graph
86+
6787

6888
def main():
6989
parser = ArgumentParser()
@@ -73,12 +93,6 @@ def main():
7393
required=True,
7494
help='Input onnx file path.'
7595
)
76-
parser.add_argument(
77-
'--output_onnx_file_path',
78-
type=str,
79-
required=True,
80-
help='Output onnx file path.'
81-
)
8296
parser.add_argument(
8397
'--input_op_names',
8498
type=str,
@@ -97,18 +111,25 @@ def main():
97111
Specify the name of the OP, separated by commas. \
98112
e.g. --output_op_names ddd,eee,fff"
99113
)
114+
parser.add_argument(
115+
'--output_onnx_file_path',
116+
type=str,
117+
default='',
118+
help='Output onnx file path. If not specified, extracted.onnx is output.'
119+
)
100120
args = parser.parse_args()
101121

102122
input_op_names = args.input_op_names.strip(' ,').replace(' ','').split(',')
103123
output_op_names = args.output_op_names.strip(' ,').replace(' ','').split(',')
104124

105125
# Model extraction
106-
extraction(
126+
extracted_graph = extraction(
107127
input_onnx_file_path=args.input_onnx_file_path,
108-
output_onnx_file_path=args.output_onnx_file_path,
109128
input_op_names=input_op_names,
110129
output_op_names=output_op_names,
130+
output_onnx_file_path=args.output_onnx_file_path,
111131
)
112132

133+
113134
if __name__ == '__main__':
114135
main()

0 commit comments

Comments
 (0)