
onnx model export #17

@dadaligoudan


Hi, I want to convert the STTN PyTorch model to ONNX format for deployment. Here is my code:
```python
import argparse
import importlib

import onnx
import torch

if __name__ == '__main__':
    # Assumed CLI: the original snippet defines `args` elsewhere
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True)  # module name under model/
    parser.add_argument('--ckpt', type=str, required=True)   # path to the .pth checkpoint
    args = parser.parse_args()

    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    # Load the original model
    net = importlib.import_module('model.' + args.model)
    model = net.InpaintGenerator().to(device)
    model.load_state_dict(torch.load(args.ckpt, map_location=device)['netG'])
    model.eval()

    # Create the wrapper
    # wrapped_model = STTN_Wrapper(model).to(device)
    wrapped_model = model.to(device)

    # Prepare sample inputs that match the real use case
    batch_size = 1
    seq_len = 11  # neighbor_nums
    height, width = 240, 432

    # Important: use the real input data format
    dummy_masked_frames = torch.randn(batch_size, seq_len, 3, height, width, device=device)
    dummy_masks = torch.randint(0, 2, (batch_size, seq_len, 1, height, width),
                                dtype=torch.float32, device=device)
    dummy_masks.requires_grad_(True)

    # Export with more detailed parameters
    onnx_path = args.ckpt.replace('.pth', '.onnx')
    torch.onnx.export(
        wrapped_model,
        (dummy_masked_frames, dummy_masks),
        onnx_path,
        export_params=True,
        opset_version=12,
        do_constant_folding=False,
        input_names=['masked_frames', 'masks'],
        output_names=['output'],
        dynamic_axes=None,
        # dynamic_axes={
        #     'masked_frames': {0: 'batch_size', 1: 'sequence_length'},
        #     'masks': {0: 'batch_size', 1: 'sequence_length'},
        #     'output': {0: 'batch_size'}
        # },
        verbose=True
    )

    # Verify the exported model
    onnx_model = onnx.load(onnx_path)
    print("Exported model inputs:")
    for i, graph_input in enumerate(onnx_model.graph.input):
        print(f"{i}. Name: {graph_input.name}, Type: {graph_input.type}")
```

The model is converted to ONNX format, but the exported model has only one input: the `masks` input is missing.

Could you help me with this? I'm confused about why the `masks` input disappears during ONNX export. Thanks a lot.
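A quick way to narrow this down is to check, in plain PyTorch, whether the model output changes at all when `masks` changes; if it does not, the exporter has nothing that consumes that input. The sketch below is illustrative only: `output_depends_on`, `DiscardedMask` and `AppliedMask` are hypothetical names. The toy modules show one pattern that makes a mask silently unused: `Tensor.masked_fill` is out-of-place, so calling it without assigning the result has no effect. Whether the attention code in `model/sttn.py` contains such a call (for example `scores.masked_fill(m, -1e9)` with the result discarded) is an assumption worth checking in your copy.

```python
import torch
import torch.nn as nn


def output_depends_on(model, frames, masks_a, masks_b, atol=1e-6):
    """Return True if swapping masks_a for masks_b changes the model output.

    A False result strongly suggests the masks never reach the output.
    """
    model.eval()
    with torch.no_grad():
        out_a = model(frames, masks_a)
        out_b = model(frames, masks_b)
    return not torch.allclose(out_a, out_b, atol=atol)


class DiscardedMask(nn.Module):
    """Hypothetical toy attention step whose mask has no effect."""

    def forward(self, frames, masks):
        scores = frames * 2.0
        # masked_fill returns a new tensor that is dropped here,
        # so `scores` is unchanged and `masks` never reaches the output.
        scores.masked_fill(masks > 0.5, -1e9)
        return torch.softmax(scores, dim=-1)


class AppliedMask(nn.Module):
    """Same toy step with the result kept (in-place masked_fill_ also works)."""

    def forward(self, frames, masks):
        scores = frames * 2.0
        scores = scores.masked_fill(masks > 0.5, -1e9)
        return torch.softmax(scores, dim=-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    frames = torch.randn(1, 4, 8)
    no_mask, full_mask = torch.zeros_like(frames), torch.ones_like(frames)
    print(output_depends_on(DiscardedMask(), frames, no_mask, full_mask))  # False
    print(output_depends_on(AppliedMask(), frames, no_mask, full_mask))    # True
```

For the STTN model above, the analogous check would be `output_depends_on(model, dummy_masked_frames, torch.zeros_like(dummy_masks), torch.ones_like(dummy_masks))`. If it returns `False`, the missing input is likely expected, and the fix belongs where the masks are applied inside the model rather than in the `torch.onnx.export` call.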
