Skip to content

Commit 8400226

Browse files
fumihwhtjingrant
authored andcommitted
keep proto when it is an output in graph (#177)
* keep proto when it is an output in graph * remove redundant list
1 parent 01d4f43 commit 8400226

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

onnx_tf/frontend.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,11 +309,14 @@ def tensorflow_graph_to_onnx_graph(cls,
309309
make_tensor_value_info(output_name, output_onnx_type,
310310
output.attr["_output_shapes"][i]))
311311

312-
inputs = list(chain.from_iterable(map(lambda p: list(p.input), ops_proto)))
313-
314-
# Remove proto in inputs_proto and consts_proto if proto is not used as input in ONNX
315-
inputs_proto = list(filter(lambda x: x.name in inputs, inputs_proto))
316-
consts_proto = list(filter(lambda x: x.name in inputs, consts_proto))
312+
inputs = list(chain.from_iterable(map(lambda p: p.input, ops_proto)))
313+
outputs = list(map(lambda p: p.name, output_proto))
314+
in_out = inputs + outputs
315+
316+
# Remove proto in inputs_proto and consts_proto
317+
# if proto is not used as input or an output in ONNX
318+
inputs_proto = list(filter(lambda x: x.name in in_out, inputs_proto))
319+
consts_proto = list(filter(lambda x: x.name in in_out, consts_proto))
317320

318321
inputs_proto = cls._data_type_caster(inputs_proto, data_type_cast_map)
319322
consts_proto = cls._data_type_caster(consts_proto, data_type_cast_map)

0 commit comments

Comments
 (0)