Skip to content

Commit 47ff949

Browse files
committed
Fix #2472, torch.onnx.export_ (with return output) finally removed :(
1 parent 3824443 commit 47ff949

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

onnx_export.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
metavar='N', help='mini-batch size (default: 1)')
4444
parser.add_argument('--img-size', default=None, type=int,
4545
metavar='N', help='Input image dimension, uses model default if empty')
46+
parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N',
47+
help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
4648
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
4749
help='Override mean pixel value of dataset')
4850
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -82,6 +84,14 @@ def main():
8284
if args.reparam:
8385
model = reparameterize_model(model)
8486

87+
if args.input_size is not None:
88+
assert len(args.input_size) == 3, 'input-size should be N H W (channels, height, width)'
89+
input_size = args.input_size
90+
elif args.img_size is not None:
91+
input_size = (3, args.img_size, args.img_size)
92+
else:
93+
input_size = None
94+
8595
onnx_export(
8696
model,
8797
args.output,
@@ -93,7 +103,7 @@ def main():
93103
training=args.training,
94104
verbose=args.verbose,
95105
use_dynamo=args.dynamo,
96-
input_size=(3, args.img_size, args.img_size),
106+
input_size=input_size,
97107
batch_size=args.batch_size,
98108
)
99109

timm/utils/onnx.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def onnx_export(
4343

4444
if example_input is None:
4545
if not input_size:
46-
assert hasattr(model, 'default_cfg')
46+
assert hasattr(model, 'default_cfg'), 'Cannot file model default config, input size must be provided'
4747
input_size = model.default_cfg.get('input_size')
4848
example_input = torch.randn((batch_size,) + input_size, requires_grad=training)
4949

@@ -80,7 +80,7 @@ def onnx_export(
8080
export_output.save(output_file)
8181
torch_out = None
8282
else:
83-
torch_out = torch.onnx._export(
83+
torch_out = torch.onnx.export(
8484
model,
8585
example_input,
8686
output_file,

0 commit comments

Comments
 (0)