Skip to content

Commit 1b8f4e6

Browse files
committed
Fine-tuning of document and logic indentation
1 parent d0d285e commit 1b8f4e6

File tree

3 files changed

+33
-13
lines changed

3 files changed

+33
-13
lines changed

README.md

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ https://github.com/PINTO0309/simple-onnx-processing-tools
77

88
# Key concept
99
- [x] If INPUT OP name and OUTPUT OP name are specified, the onnx graph within the range of the specified OP name is extracted and .onnx is generated.
10-
- [x] Change backend to `onnx.utils.Extractor.extract_model` so that onnx.ModelProto can be specified as input.
10+
- [x] I do not use `onnx.utils.extractor.extract_model` because it is very slow and I implement my own model separation logic.
1111

1212
## 1. Setup
1313
### 1-1. HostPC
@@ -18,6 +18,7 @@ $ echo export PATH="~/.local/bin:$PATH" >> ~/.bashrc \
1818

1919
### run
2020
$ pip install -U onnx \
21+
&& python3 -m pip install -U onnx_graphsurgeon --index-url https://pypi.ngc.nvidia.com
2122
&& pip install -U sne4onnx
2223
```
2324
### 1-2. Docker
@@ -37,18 +38,18 @@ usage:
3738

3839
optional arguments:
3940
-h, --help
40-
show this help message and exit
41+
show this help message and exit.
4142

4243
--input_onnx_file_path INPUT_ONNX_FILE_PATH
4344
Input onnx file path.
4445

4546
--input_op_names INPUT_OP_NAMES
4647
List of OP names to specify for the input layer of the model.
47-
e.g. --input_op_names aaa bbb ccc
48+
e.g. --input_op_names aaa bbb ccc
4849

4950
--output_op_names OUTPUT_OP_NAMES
5051
List of OP names to specify for the output layer of the model.
51-
e.g. --output_op_names ddd eee fff
52+
e.g. --output_op_names ddd eee fff
5253

5354
--output_onnx_file_path OUTPUT_ONNX_FILE_PATH
5455
Output onnx file path. If not specified, extracted.onnx is output.
@@ -124,8 +125,8 @@ $ sne4onnx \
124125
from sne4onnx import extraction
125126

126127
extracted_graph = extraction(
127-
input_op_names=['aaa', 'bbb', 'ccc'],
128-
output_op_names=['ddd', 'eee', 'fff'],
128+
input_op_names=['aaa','bbb','ccc'],
129+
output_op_names=['ddd','eee','fff'],
129130
input_onnx_file_path='input.onnx',
130131
output_onnx_file_path='output.onnx',
131132
)
@@ -135,8 +136,8 @@ extracted_graph = extraction(
135136
from sne4onnx import extraction
136137

137138
extracted_graph = extraction(
138-
input_op_names=['aaa', 'bbb', 'ccc'],
139-
output_op_names=['ddd', 'eee', 'fff'],
139+
input_op_names=['aaa','bbb','ccc'],
140+
output_op_names=['ddd','eee','fff'],
140141
onnx_graph=graph,
141142
output_onnx_file_path='output.onnx',
142143
)

sne4onnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from sne4onnx.onnx_network_extraction import extraction, main
22

3-
__version__ = '1.0.6'
3+
__version__ = '1.0.7'

sne4onnx/onnx_network_extraction.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,35 @@ def extraction(
107107
graph.cleanup().toposort()
108108

109109
# Extraction of input OP and output OP
110-
graph_node_inputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_input in graph_nodes.inputs if graph_nodes_input.name in input_op_names]
111-
graph_node_outputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_output in graph_nodes.outputs if graph_nodes_output.name in output_op_names]
110+
graph_node_inputs = [
111+
graph_nodes \
112+
for graph_nodes in graph.nodes \
113+
for graph_nodes_input in graph_nodes.inputs \
114+
if graph_nodes_input.name in input_op_names
115+
]
116+
graph_node_outputs = [
117+
graph_nodes \
118+
for graph_nodes in graph.nodes \
119+
for graph_nodes_output in graph_nodes.outputs \
120+
if graph_nodes_output.name in output_op_names
121+
]
112122

113123
# Init graph INPUT/OUTPUT
114124
graph.inputs.clear()
115125
graph.outputs.clear()
116126

117127
# Update graph INPUT/OUTPUT
118-
graph.inputs = [graph_node_input for graph_node in graph_node_inputs for graph_node_input in graph_node.inputs if graph_node_input.shape]
119-
graph.outputs = [graph_node_output for graph_node in graph_node_outputs for graph_node_output in graph_node.outputs]
128+
graph.inputs = [
129+
graph_node_input \
130+
for graph_node in graph_node_inputs \
131+
for graph_node_input in graph_node.inputs \
132+
if graph_node_input.shape
133+
]
134+
graph.outputs = [
135+
graph_node_output \
136+
for graph_node in graph_node_outputs \
137+
for graph_node_output in graph_node.outputs
138+
]
120139

121140
# Cleanup
122141
graph.cleanup().toposort()

0 commit comments

Comments
 (0)