Description
Describe the bug
I want to convert a custom-trained PyTorch Faster R-CNN model, based on the pretrained torchvision model fasterrcnn_mobilenet_v3_large_320_fpn, to TFLite format. The intended pipeline is PyTorch -> ONNX -> TensorFlow -> TFLite. The conversion to ONNX succeeded, but the conversion from ONNX to TensorFlow fails with an error. The error is shown in the image below.
To Reproduce
Below are the steps I have taken so far:
Step 1: convert the model to ONNX
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def export_to_onnx():
    model_path = '/content/model.pth'
    # Load the full pickled model (not just a state_dict) onto the CPU
    model = torch.load(model_path, map_location=torch.device('cpu'))
    model.eval()
    # Dummy input for tracing; I used an 800x800 image
    dummy_input = torch.randn(1, 3, 800, 800)
    torch.onnx.export(model, dummy_input, "/content/gdrive/MyDrive/Faster_RCNN/fastercnn-pytorch-training-pipeline/weights/fasterrcnn_test4.onnx", opset_version=11, input_names=['input'], output_names=['output'])

export_to_onnx()
Step 2: simplify the ONNX model
!pip install --upgrade onnx-tf -q
!pip install onnx-simplifier -q
!python -m onnxsim "/content/fasterrcnn_test4.onnx" "/content/fasterrcnn_test4_new.onnx"
Step 3: convert the simplified ONNX model to TensorFlow
import onnx
from onnx_tf.backend import prepare

def convert_onnx_to_tf(onnx_path, tf_path):
    # Load the ONNX model
    model = onnx.load(onnx_path)
    # Prepare the TensorFlow representation
    tf_rep = prepare(model)
    # Export the model to TensorFlow
    tf_rep.export_graph(tf_path)

convert_onnx_to_tf("/content/fasterrcnn_test4_new.onnx", "/content/fasterrcnn_test4")
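For completeness, the final TensorFlow -> TFLite step that I have not reached yet would look roughly like the sketch below. It assumes that export_graph wrote a SavedModel directory at /content/fasterrcnn_test4; the function name, the output path, and the allow_select_tf_ops parameter are made up for illustration. Enabling SELECT_TF_OPS is also an assumption: detection models often contain ops without a TFLite builtin equivalent, but I have not confirmed that this model needs it.

```python
def convert_saved_model_to_tflite(saved_model_dir, tflite_path, allow_select_tf_ops=True):
    """Convert a TensorFlow SavedModel directory to a .tflite file.

    Returns the size of the written TFLite flatbuffer in bytes.
    """
    import tensorflow as tf  # imported lazily; only needed when converting

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    if allow_select_tf_ops:
        # Assumption: fall back to TensorFlow ops where no TFLite builtin exists.
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS,
        ]
    tflite_model = converter.convert()
    with open(tflite_path, "wb") as f:
        f.write(tflite_model)
    return len(tflite_model)


# Hypothetical usage, once Step 3 succeeds:
# convert_saved_model_to_tflite("/content/fasterrcnn_test4", "/content/fasterrcnn_test4.tflite")
```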
ONNX & PTH model files
ONNX Model:
https://drive.google.com/file/d/1-3Dea11Y58_PUFQyCsZxrxyYEHDlta0R/view?usp=sharing
PTH Model:
https://drive.google.com/file/d/1eJtb-DFh65oCpUB4Vxd3RLa_IrvEzulm/view?usp=sharing
Python, ONNX, ONNX-TF, Tensorflow version
- Python version: 3.10.12
- ONNX version: 1.16.0
- ONNX-TF version: 1.10.0
- TensorFlow version: 2.15.0
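In case it helps with reproducing, the versions above can be read from the running environment with a small helper like the one below. This is only a sketch: collect_versions is a name I made up, and the distribution names in the list (for example "onnx-simplifier") are assumptions that may differ from what pip actually installed.

```python
from importlib import metadata


def collect_versions(packages):
    """Return {distribution name: version string, or 'not installed'}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    import platform

    print("python", platform.python_version())
    # Assumed distribution names; adjust if pip lists them differently.
    names = ["torch", "torchvision", "onnx", "onnx-tf", "onnx-simplifier", "tensorflow"]
    for name, version in collect_versions(names).items():
        print(name, version)
```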
Additional context
I am still new to computer vision. I have looked for alternative conversion routes but have not found one that works. Note that when I use the pretrained fasterrcnn_mobilenet_v3_large_320_fpn model from torchvision without custom training, the conversion with onnx-tf succeeds, so the failure seems specific to the custom-trained model. Has anyone else run into the same problem? If so, could you share how you resolved it?
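Since the pretrained export converts and the custom-trained one does not, one way to narrow this down might be to compare which ONNX operators each file contains. The sketch below is only an idea, not something I have verified to expose the cause: load_op_types and diff_op_counts are hypothetical helper names, the pretrained path is a placeholder, and only the top-level graph is inspected (nodes inside If/Loop subgraphs are not visited).

```python
from collections import Counter


def load_op_types(onnx_path):
    """Return the op_type of every node in the top-level graph of an ONNX file."""
    import onnx  # imported lazily so diff_op_counts works without onnx installed

    model = onnx.load(onnx_path)
    # Note: subgraphs nested inside If/Loop nodes are not traversed here.
    return [node.op_type for node in model.graph.node]


def diff_op_counts(working_ops, failing_ops):
    """Return {op_type: (count_in_working, count_in_failing)} where the counts differ."""
    working, failing = Counter(working_ops), Counter(failing_ops)
    return {
        op: (working[op], failing[op])
        for op in sorted(set(working) | set(failing))
        if working[op] != failing[op]
    }


# Hypothetical usage; "pretrained.onnx" is a placeholder for the export that converts fine:
# diff = diff_op_counts(
#     load_op_types("pretrained.onnx"),
#     load_op_types("/content/fasterrcnn_test4_new.onnx"),
# )
# for op, (ok_count, bad_count) in diff.items():
#     print(op, ok_count, bad_count)
```

Operators that appear only in the failing model, or in very different numbers, would be the first candidates to check against the operators that onnx-tf supports.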