
Fine tuned model to onnx/tensorrt/tflite conversion? #31


Open
dsbyprateekg opened this issue Mar 23, 2025 · 15 comments
Labels
exports Model exports (ONNX, TensorRT, TFLite, etc.)

Comments

@dsbyprateekg

Please share the notebook/code to convert the fine-tuned model to ONNX/TensorRT/TFLite format, along with the inference code for each.

@probicheaux
Collaborator

probicheaux commented Mar 23, 2025

onnx:

from rfdetr import RFDETRBase
x = RFDETRBase()
x.export()

then afterward, in bash, if you want an FP16 TRT engine:

trtexec --onnx=output/inference_model.onnx --fp16

we haven't tested TFLite yet
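
If it helps, here is a quick sanity check of the exported ONNX with onnxruntime before building the engine. This is only a sketch: the path matches the default export output above, and the dummy input is just there to confirm the graph runs.

import numpy as np
import onnxruntime as ort

# Load the exported model; the input name and shape are read from the graph itself
sess = ort.InferenceSession("output/inference_model.onnx", providers=["CPUExecutionProvider"])
inp = sess.get_inputs()[0]

# Build a dummy input, substituting 1 for any dynamic dimension
shape = [d if isinstance(d, int) else 1 for d in inp.shape]
dummy = np.zeros(shape, dtype=np.float32)

outputs = sess.run(None, {inp.name: dummy})
print([o.shape for o in outputs])  # the detection boxes and class logits outputs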

@probicheaux
Collaborator

if anyone wants to share their experience getting it working with tflite that would be cool!

@Artem-N

Artem-N commented Mar 24, 2025

if anyone wants to share their experience getting it working with tflite that would be cool!

Hey, I'm trying to run my custom .trt model, but the base RFDETR class doesn't support .trt models, if I understand correctly.
I load the model something like this:
model = RFDETRBase(num_classes=3, pretrain_weights="model_custom.trt")

and the error is:
Failed to load pretrain weights: invalid load key, 'f'.
Failed to load pretrain weights, re-downloading
Traceback (most recent call last):
File "/home/test/.local/lib/python3.10/site-packages/rfdetr/main.py", line 79, in init
checkpoint = torch.load(args.pretrain_weights, map_location='cpu', weights_only=False)
File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 1040, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "/home/test/.local/lib/python3.10/site-packages/torch/serialization.py", line 1262, in _legacy_load
magic_number = pickle_module.load(f, **pickle_load_args)
_pickle.UnpicklingError: invalid load key, 'f'.

@dsbyprateekg
Author

onnx:

from rfdetr import RFDETRBase
x = RFDETRBase()
x.export()
then afterward, in bash, if you want an FP16 TRT engine:

trtexec --onnx=output/inference_model.onnx --fp16
we haven't tested TFLite yet

Please share the inference code using the TRT model.

@dsbyprateekg
Author

dsbyprateekg commented Mar 25, 2025

if anyone wants to share their experience getting it working with tflite that would be cool!

I was able to successfully convert to TFLite. I have created the following PR with inference code:
https://github.com/roboflow/rf-detr/pull/45
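
For anyone who just wants to poke at a converted .tflite file, the generic TFLite Interpreter flow looks roughly like this. It is only a sketch, not the code from the PR; the file name and input layout are assumptions.

import numpy as np
import tensorflow as tf

# Hypothetical path to the converted model
interpreter = tf.lite.Interpreter(model_path="inference_model.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Dummy input matching whatever shape/dtype the conversion produced
dummy = np.zeros(input_details[0]["shape"], dtype=input_details[0]["dtype"])
interpreter.set_tensor(input_details[0]["index"], dummy)
interpreter.invoke()

outputs = [interpreter.get_tensor(d["index"]) for d in output_details]
print([o.shape for o in outputs])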

@isaacrob-roboflow
Collaborator

we don't have built-in infra to serve trt models. we assumed people using trt would have their own infra. can you say more about your use case? are there other detector repos that you have used that provide native trt serving? @Artem-N @dsbyprateekg

@dsbyprateekg
Author

we don't have built-in infra to serve trt models. we assumed people using trt would have their own infra. can you say more about your use case? are there other detector repos that you have used that provide native trt serving? @Artem-N @dsbyprateekg

Yes @isaacrob-roboflow. I have been using YOLOv7, which provides all the steps to convert to ONNX/TensorRT and also provides a notebook for TRT inference: https://colab.research.google.com/github/WongKinYiu/yolov7/blob/main/tools/YOLOv7trt.ipynb
Also, the TRT format is widely used to make inference faster on GPU devices.

@Artem-N

Artem-N commented Mar 27, 2025

we don't have built-in infra to serve trt models. we assumed people using trt would have their own infra. can you say more about your use case? are there other detector repos that you have used that provide native trt serving? @Artem-N @dsbyprateekg

So far, as an experiment, I have used the infrastructure from Ultralytics on my Orin NX 16GB board.

In order to run the rf-detr model on video in real time, I have modified the code a little, and I'm running it well from my machine on an RTX 4090.

But it would be interesting to try to run your model on a Jetson board (Orin NX 16GB), and for that it is clearly necessary to use a model converted to TensorRT, for speedup.

Here is the code I use for running it:

import cv2
import numpy as np
from PIL import Image
import torch
from rfdetr import RFDETRBase
import supervision as sv
from tqdm import tqdm  # Import tqdm for progress bar

custom_labels = ["person", "car", "truck"]

model = RFDETRBase(num_classes=3,
                   pretrain_weights=r"D:\pycharm_projects\rf-detr\model_trained\checkpoint_best_total.pth")

video_source = r"source.mp4"
cap = cv2.VideoCapture(video_source)

total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

ret, frame = cap.read()
if not ret:
    print("Error: Could not read video stream.")
    cap.release()
    exit(1)

rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(rgb_frame)
resolution = pil_image.size

color = sv.ColorPalette.from_hex([
    "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
    "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
])

base_text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution)
text_scale = base_text_scale * 0.5  # Smaller text

base_thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution)
thickness = max(1, int(base_thickness * 0.5))  # Thinner bounding boxes

bbox_annotator = sv.BoxAnnotator(color=color, thickness=thickness)

label_annotator = sv.LabelAnnotator(
    color=color,
    text_color=sv.Color.BLACK,
    text_scale=text_scale,
    smart_position=True,
    border_radius=1  # Adjust as needed (no background drawn)
)

cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

frame_skip = 2  # Process every 2nd frame
frame_count = 0
last_annotated_frame = None

progress = tqdm(total=total_frames, desc="Processing Video Frames")

while True:
    ret, frame = cap.read()
    if not ret:
        break
    frame_count += 1
    progress.update(1)  # Update progress bar for each frame

    # Skip frames if desired
    if frame_count % frame_skip != 0:
        if last_annotated_frame is not None:
            cv2.imshow("Video Detection", last_annotated_frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
        continue

    # Convert frame from BGR to RGB then to PIL image
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb_frame)

    # Optionally, resize pil_image to a lower resolution for faster processing
    small_size = (pil_image.width // 2, pil_image.height // 2)
    small_pil = pil_image.resize(small_size)

    # Run detection on the smaller frame (adjust threshold if needed)
    detections = model.predict(small_pil, threshold=0.3)

    # Use custom labels for your classes
    labels = [
        f"{custom_labels[class_id]} {confidence:.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]

    # Annotate the smaller image then resize back for display
    annotated_small = small_pil.copy()
    annotated_small = bbox_annotator.annotate(annotated_small, detections)
    annotated_small = label_annotator.annotate(annotated_small, detections, labels)
    annotated_image = annotated_small.resize((pil_image.width, pil_image.height))

    # Convert annotated image back to BGR for OpenCV display
    annotated_frame = cv2.cvtColor(np.array(annotated_image), cv2.COLOR_RGB2BGR)
    last_annotated_frame = annotated_frame  # Save for skipped frames

    cv2.imshow("Video Detection", annotated_frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

progress.close()
cap.release()
cv2.destroyAllWindows()

and when I change the weights here - model = RFDETRBase(num_classes=3, pretrain_weights="path to custom model") - to my converted .trt model, I can't run it.

In general I use Ultralytics, but I can also work with YOLOv7 architectures.
Thanks!
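
(For reference: per the traceback above, RFDETRBase loads pretrain_weights with torch.load, so it only accepts .pth checkpoints; that is why a serialized .trt engine fails with the pickle error. A .trt/.engine file has to be driven by TensorRT directly. Below is a rough sketch using the TensorRT 8.x bindings API plus pycuda; the file name, binding order, and preprocessing are assumptions, and TensorRT 10 replaces this bindings API with the tensor-name based one.)

import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401 - creates a CUDA context

# Deserialize the engine built with trtexec
logger = trt.Logger(trt.Logger.WARNING)
with open("inference_model.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

# Allocate host/device buffers for every binding (TensorRT 8.x-style API)
bindings, host_bufs, dev_bufs = [], [], []
for i in range(engine.num_bindings):
    shape = tuple(engine.get_binding_shape(i))
    dtype = trt.nptype(engine.get_binding_dtype(i))
    host = np.zeros(shape, dtype=dtype)
    dev = cuda.mem_alloc(host.nbytes)
    bindings.append(int(dev))
    host_bufs.append(host)
    dev_bufs.append(dev)

# img stands in for a preprocessed (1, 3, H, W) frame matching the export resolution;
# binding 0 is assumed to be the input
img = np.zeros(host_bufs[0].shape, dtype=host_bufs[0].dtype)
host_bufs[0][...] = img
cuda.memcpy_htod(dev_bufs[0], host_bufs[0])
context.execute_v2(bindings)
for i in range(1, engine.num_bindings):
    cuda.memcpy_dtoh(host_bufs[i], dev_bufs[i])  # raw dets/labels tensors to post-process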

@SkalskiP SkalskiP added the exports Model exports (ONNX, TensorRT, TFLite, etc.) label Mar 31, 2025
@Ben93kie

Ben93kie commented Apr 4, 2025

@dsbyprateekg
Hey, just looking into this. Have you had any success running this as an engine? I'm interested in DeepStream deployment.

@ZarCS

ZarCS commented May 15, 2025

trtexec --onnx=output/inference_model.onnx --fp16

Hi @probicheaux,
I'm working on a Jetson Nano. Do I have to run the above on the Jetson, or can it be done on another system and then moved to the Jetson Nano to run?

@isaacrob-roboflow
Collaborator

that's the right command! trt engines aren't generally cross-platform compatible so you will likely have to do it on the jetson itself. also, we suggest using jetpack 6.2, which comes with trt 10.3. newer versions of trt come with speedups for the particular architecture we use for our backbone, and we benchmark with trt 10.4
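
(A quick way to confirm which TensorRT version the Jetson's Python environment sees, assuming the JetPack TensorRT Python bindings are installed:)

import tensorrt as trt
print(trt.__version__)  # e.g. 10.3.x on JetPack 6.2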

@Willjay90

Willjay90 commented May 15, 2025

@Ben93kie I'm able to run it with deepstream nvcr.io/nvidia/deepstream:7.1-gc-triton-devel

Steps

  1. Export the model to TensorRT
# .pth -> .onnx
from rfdetr import RFDETRBase

resolution = 728
output_dir = f"trt_{resolution}"

model = RFDETRBase(pretrain_weights=f"output/{resolution}_checkpoint_best_ema.pth", resolution=resolution)

model.export(output_dir=f"{output_dir}", infer_dir=None, simplify=True,  backbone_only=False)
# .onnx -> .engine (nvcr.io/nvidia/deepstream:7.1-gc-triton-devel)
trtexec --onnx="inference_model.onnx" \
        --saveEngine="inference_model.engine" \
        --memPoolSize=workspace:4096 --fp16 \
        --useCudaGraph --useSpinWait --warmUp=500 --avgRuns=1000 --duration=10 --verbose
  2. Create the custom parser plugin rf_detr_parser.cpp (ref). A NumPy sketch of the same post-processing appears after step 5 below.
g++ -Wall -shared -fPIC \
  -I /usr/local/cuda/include \
  -I /opt/nvidia/deepstream/deepstream/sources/includes \
  -o nvdsinfer_rf_detr_parser.so rf_detr_parser.cpp
#include <nvdsinfer_custom_impl.h>
#include <cstring>
#include <vector>
#include <iostream>
#include <cmath>
#include <algorithm>

extern "C"
bool NvDsInferRFDETRParser(
    std::vector<NvDsInferLayerInfo> const &outputLayersInfo,
    NvDsInferNetworkInfo const &networkInfo,
    NvDsInferParseDetectionParams const &detectionParams,
    std::vector<NvDsInferObjectDetectionInfo> &objectList)
{
    const NvDsInferLayerInfo *dets_layer = nullptr;
    const NvDsInferLayerInfo *labels_layer = nullptr;

    // Find output layers
    for (auto &layer : outputLayersInfo) {
        if (strcmp(layer.layerName, "dets") == 0) {
            dets_layer = &layer;
        } else if (strcmp(layer.layerName, "labels") == 0) {
            labels_layer = &layer;
        }
    }

    if (!dets_layer || !labels_layer) {
        std::cerr << "Missing output layer: dets or labels\n";
        return false;
    }

    // Get dimensions
    int num_detections = dets_layer->inferDims.d[0];
    int dets_per_detection = dets_layer->inferDims.d[1];
    int labels_per_detection = labels_layer->inferDims.d[1];

    // Validate dimensions
    if (dets_layer->inferDims.d[2] != 0) {
        std::cerr << "Unexpected third dimension in dets: " << dets_layer->inferDims.d[2] << "\n";
        return false;
    }
    if (labels_layer->inferDims.d[2] != 0) {
        std::cerr << "Unexpected third dimension in labels: " << labels_layer->inferDims.d[2] << "\n";
        return false;
    }
    if (dets_per_detection != 4) {
        std::cerr << "Unexpected dets dimension: expected 4, got " << dets_per_detection << "\n";
        return false;
    }

    auto *dets = static_cast<float*>(dets_layer->buffer);
    auto *labels = static_cast<float*>(labels_layer->buffer);

    if (!dets || !labels) {
        std::cerr << "Null buffer for dets or labels\n";
        return false;
    }

    float frame_width = static_cast<float>(networkInfo.width);
    float frame_height = static_cast<float>(networkInfo.height);

    // Structure to hold score, label, and box index for top-k selection
    struct Prediction {
        float score;
        int label;
        int box_index;
        bool operator<(const Prediction &other) const {
            return score > other.score; // Sort descending
        }
    };

    // Compute sigmoid scores and prepare for top-k
    std::vector<Prediction> predictions;
    predictions.reserve(num_detections * labels_per_detection);

    for (int i = 0; i < num_detections; ++i) {
        int offset_label = i * labels_per_detection;
        for (int j = 0; j < labels_per_detection; ++j) {
            float logit = labels[offset_label + j];
            float score = 1.0f / (1.0f + std::exp(-logit)); // Sigmoid
            if (score >= detectionParams.perClassPreclusterThreshold[j]) {
                predictions.push_back({score, j, i});
            }
        }
    }

    // Select top 300 predictions
    const int top_k = std::min(300, static_cast<int>(predictions.size()));
    std::partial_sort(predictions.begin(), predictions.begin() + top_k, predictions.end());

    // Process top-k predictions
    for (int k = 0; k < top_k; ++k) {
        const auto &pred = predictions[k];
        int i = pred.box_index;
        int class_id = pred.label;
        float score = pred.score;

        // Get box coordinates
        int offset_box = i * dets_per_detection;
        float x_c = dets[offset_box + 0];
        float y_c = dets[offset_box + 1];
        float w = dets[offset_box + 2];
        float h = dets[offset_box + 3];

        // Clamp w and h to minimum 0.0
        w = std::max(w, 0.0f);
        h = std::max(h, 0.0f);

        // Convert cx,cy,w,h to x1,y1,x2,y2 (box_cxcywh_to_xyxy)
        float x1 = (x_c - 0.5f * w) * frame_width;
        float y1 = (y_c - 0.5f * h) * frame_height;
        float x2 = (x_c + 0.5f * w) * frame_width;
        float y2 = (y_c + 0.5f * h) * frame_height;

        // Convert to NvDsInferObjectDetectionInfo format
        float left = x1;
        float top = y1;
        float width = x2 - x1;
        float height = y2 - y1;

        // Skip invalid boxes
        if (width <= 0 || height <= 0) {
            continue;
        }

        NvDsInferObjectDetectionInfo obj = {0};
        obj.classId = class_id;
        obj.detectionConfidence = roundf(score * 100.0f) / 100.0f;
        obj.left = left;
        obj.top = top;
        obj.width = width;
        obj.height = height;

        objectList.push_back(obj);
    }

    return true;
}
  3. ds_config.txt
[property]
gie-unique-id=1
infer-dims=3;728;728        # CHW
net-scale-factor=0.00392156862745098
network-input-order=0
model-color-format=0   # RGB
num-detected-classes=6   # num of classes
network-mode=2    # fp16
cluster-mode=2
maintain-aspect-ratio=1
symmetric-padding=1
offsets=0;0;0
onnx-file=inference_model.onnx
model-engine-file=inference_model.engine
labelfile-path=labels.txt
output-blob-names=dets;labels
parse-bbox-func-name=NvDsInferRFDETRParser # custom bbox parser function name
custom-lib-path=nvdsinfer_rf_detr_parser.so # custom bbox parser

[class-attrs-all]
threshold=0.3

[class-attrs-5]
pre-cluster-threshold=0.6
  4. labels.txt
person
bicycle
car
motorcycle
airplane
bus
  5. Inference
gst-launch-1.0 filesrc location=videos/1.mp4 ! decodebin ! queue ! nvvideoconvert ! capsfilter caps="video/x-raw(memory:NVMM),format=NV12" ! mux.sink_0 nvstreammux name=mux batch-size=1 width=1920 height=1080 batched-push-timeout=40000 ! nvinfer config-file-path=ds_config.txt ! nvvideoconvert ! nvdsosd ! nvvideoconvert ! nvv4l2h264enc ! h264parse ! qtmux ! filesink location=videos/output.mp4
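
For anyone doing the same post-processing in plain Python instead of DeepStream, here is a rough NumPy equivalent of the parser in step 2. It is only a sketch; it assumes the exported dets output is num_queries x 4 normalized cx,cy,w,h and labels is num_queries x num_classes raw logits, as in the parser above.

import numpy as np

def postprocess(dets, logits, image_wh, top_k=300, threshold=0.3):
    # Sigmoid over the raw class logits, exactly as the C++ parser does
    scores = 1.0 / (1.0 + np.exp(-logits))

    # Top-k over all (query, class) pairs, then drop low-confidence ones
    flat = scores.reshape(-1)
    order = np.argsort(-flat)[:top_k]
    order = order[flat[order] >= threshold]
    query_idx, class_idx = np.divmod(order, scores.shape[1])

    # cx,cy,w,h (normalized) -> x1,y1,x2,y2 in pixels
    cx, cy, w, h = dets[query_idx].T
    W, H = image_wh
    boxes = np.stack([(cx - 0.5 * w) * W, (cy - 0.5 * h) * H,
                      (cx + 0.5 * w) * W, (cy + 0.5 * h) * H], axis=1)
    return boxes, class_idx, flat[order]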

@Ben93kie

ok, wow thx @Willjay90 I'll try it out tmrw

@DatSplit

Good evening @ZarCS,

In PR #175 there is an example of how to convert a model to ONNX/TensorRT and deploy it for real-time inference in Python.

@Willjay90

Hi @Ben93kie,
I updated the parser code based on the postprocess step in the repo.
