-
-
Notifications
You must be signed in to change notification settings - Fork 598
/
Copy pathbatchsize_clear.py
41 lines (34 loc) · 1.34 KB
/
batchsize_clear.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import onnx
import os
import struct
from argparse import ArgumentParser
def rebatch(infile, outfile, batch_size):
    """Rewrite the batch dimension of an ONNX model and save the result.

    Dimension 0 of every graph input (except initializer-backed inputs),
    every ``value_info`` entry and every graph output is replaced by
    ``batch_size``.

    Every ``Reshape`` node gets its first target dimension set to ``-1``
    when its target shape comes from an initializer. This lets the reshape
    follow whatever batch size arrives at runtime.

    Args:
        infile: Path of the model to read (passed to ``onnx.load``).
        outfile: Path the modified model is written to (``onnx.save``).
        batch_size: A ``str`` is stored as a symbolic ``dim_param``; an
            ``int`` is stored as a fixed ``dim_value``.
    """
    model = onnx.load(infile)
    graph = model.graph
    _set_batch_dim(graph, batch_size)
    _make_reshapes_dynamic(graph)
    onnx.save(model, outfile)


def _set_batch_dim(graph, batch_size):
    """Set dim 0 of every non-initializer input, value_info and output."""
    # Older models (IR < 4) also list weights in graph.input; the leading
    # dimension of a weight is not a batch dimension, so leave those alone.
    initializer_names = {init.name for init in graph.initializer}
    tensors = [t for t in graph.input if t.name not in initializer_names]
    tensors += list(graph.value_info) + list(graph.output)
    for tensor in tensors:
        dims = tensor.type.tensor_type.shape.dim
        # Scalars, tensors without a shape and non-tensor types have no
        # dim 0; indexing dims[0] on them would raise IndexError.
        if not dims:
            continue
        if isinstance(batch_size, int):
            dims[0].dim_value = batch_size
        else:
            dims[0].dim_param = batch_size


def _make_reshapes_dynamic(graph):
    """Set the first target dim of initializer-fed Reshape nodes to -1."""
    # Build the name lookup once instead of rescanning every initializer
    # for every node.
    initializers = {init.name: init for init in graph.initializer}
    for node in graph.node:
        # Opset-1 Reshape carries its shape as an attribute and has a single
        # input; only the form whose input[1] is the target shape is handled.
        if node.op_type != 'Reshape' or len(node.input) < 2:
            continue
        init = initializers.get(node.input[1])
        if init is None:
            # NOTE(review): shapes produced by other nodes (e.g. Constant)
            # are not rewritten here.
            continue
        if len(init.int64_data) > 0:
            # Shape stored as a list of ints. This also overwrites bias
            # nodes' reshape shape, which the original author considered fine.
            # NOTE(review): if another entry is already -1 the result holds
            # two -1s, which Reshape rejects.
            init.int64_data[0] = -1
        elif len(init.raw_data) > 0:
            # Shape stored as bytes. ONNX raw_data is little-endian, so pack
            # with '<q' rather than native byte order ('q').
            shape = bytearray(init.raw_data)
            struct.pack_into('<q', shape, 0, -1)
            init.raw_data = bytes(shape)
if __name__ == '__main__':
    # Command-line entry point: rewrite INFILE's batch dimension to the
    # symbolic value '-1' and write the result to OUTFILE.
    # ArgumentParser's first positional parameter is ``prog`` (the program
    # name shown in usage lines), so the text must be passed as description=.
    parser = ArgumentParser(description='Replace batch size with \'-1\'')
    parser.add_argument('infile', help='path of the ONNX model to read')
    parser.add_argument('outfile', help='path to write the modified model to')
    args = parser.parse_args()
    rebatch(args.infile, args.outfile, '-1')