Skip to content

Commit 1915755

Browse files
committed
fix inference fail when protobuf size larger than 2GB
Signed-off-by: inisis <desmond.yao@buaa.edu.cn>
1 parent 9d2613d commit 1915755

File tree

1 file changed

+22
-3
lines changed
  • tools/onnx-graphsurgeon/onnx_graphsurgeon/ir

1 file changed

+22
-3
lines changed

tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/graph.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,12 +1245,31 @@ def should_eval_foldable(tensor):
12451245
else:
12461246
names = [t.name for t in graph_clone.outputs]
12471247
try:
1248+
import os
1249+
import tempfile
1250+
import onnx
12481251
import onnxruntime as onnxrt
12491252

1253+
onnx_model = export_onnx(graph_clone, do_type_check=False)
1254+
if onnx_model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
1255+
tmp_dir = tempfile.TemporaryDirectory()
1256+
tmp_path = os.path.join(tmp_dir.name, "tmp.onnx")
1257+
location = os.path.basename(tmp_path) + ".data"
1258+
if os.path.exists(location):
1259+
os.remove(location)
1260+
onnx.save(
1261+
onnx_model,
1262+
tmp_path,
1263+
save_as_external_data=True,
1264+
all_tensors_to_one_file=True,
1265+
location=location,
1266+
)
1267+
onnx_model = tmp_path
1268+
else:
1269+
onnx_model = onnx_model.SerializeToString()
1270+
12501271
sess = onnxrt.InferenceSession(
1251-
export_onnx(
1252-
graph_clone, do_type_check=False
1253-
).SerializeToString(),
1272+
onnx_model,
12541273
providers=ORT_PROVIDERS,
12551274
)
12561275
values = sess.run(names, {})

0 commit comments

Comments
 (0)