Hi, I want to convert the STTN PyTorch model to ONNX format for deployment. Following is my code:
import importlib

import torch

if __name__ == '__main__':
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    # Load the original model (args.model / args.ckpt come from the script's argparse setup)
    net = importlib.import_module('model.' + args.model)
    model = net.InpaintGenerator().to(device)
    model.load_state_dict(torch.load(args.ckpt, map_location=device)['netG'])
    model.eval()
    # Create the wrapper
    # wrapped_model = STTN_Wrapper(model).to(device)
    wrapped_model = model.to(device)
    # Prepare example inputs matching the real use case
    batch_size = 1
    seq_len = 11  # neighbor_nums
    height, width = 240, 432
    # Important: use the real input data format
    dummy_masked_frames = torch.randn(batch_size, seq_len, 3, height, width, device=device)
    dummy_masks = torch.randint(0, 2, (batch_size, seq_len, 1, height, width),
                                dtype=torch.float32, device=device)
    dummy_masks.requires_grad_(True)
    # 3. Use more detailed export parameters
    torch.onnx.export(
        wrapped_model,
        (dummy_masked_frames, dummy_masks),
        args.ckpt.replace('.pth', '.onnx'),
        export_params=True,
        opset_version=12,
        do_constant_folding=False,
        input_names=['masked_frames', 'masks'],
        output_names=['output'],
        dynamic_axes=None,
        # dynamic_axes={
        #     'masked_frames': {0: 'batch_size', 1: 'sequence_length'},
        #     'masks': {0: 'batch_size', 1: 'sequence_length'},
        #     'output': {0: 'batch_size'}
        # },
        verbose=True
    )
    # 4. Verify the export result
    import onnx
    onnx_model = onnx.load(args.ckpt.replace('.pth', '.onnx'))
    print("Exported model inputs:")
    for i, input in enumerate(onnx_model.graph.input):
        print(f"{i}. Name: {input.name}, Type: {input.type}")
The model is exported to ONNX format, but the exported graph has only one input.
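One way to confirm which inputs survived export is to open the file with onnxruntime and list the graph's inputs and outputs. Trace-based export generally keeps only the inputs that actually reach the output, so if 'masks' never influences the traced computation it will not appear as a graph input. The sketch below is a minimal check, not a fix; the file path is a placeholder and should point at the .onnx file written by the export script.

    # Minimal sketch: inspect the exported graph with onnxruntime.
    import onnxruntime as ort

    onnx_path = "sttn.onnx"  # placeholder; use args.ckpt.replace('.pth', '.onnx')
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    # If 'masks' is missing from this list, the traced graph never consumed
    # that tensor on the path to the output.
    for inp in sess.get_inputs():
        print("input:", inp.name, inp.shape, inp.type)
    for out in sess.get_outputs():
        print("output:", out.name, out.shape, out.type)

If the check shows only 'masked_frames', it may be worth verifying how InpaintGenerator.forward (or the commented-out STTN_Wrapper) uses the masks tensor before export.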
