Skip to content

[ONNX] Add dynamic shapes support (& in-browser inference w/ Transformers.js) #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from

Conversation

xenova
Copy link

@xenova xenova commented Mar 29, 2025

Description

This PR adds support for exporting RF-DETR to ONNX with dynamic input and output shapes. This means you can now supply images of variable batch_size, height, and width, both for backbone-only and full exports (provided width and height are divisible by 56).

I have uploaded the checkpoints (along with various quantizations) to the Hugging Face Hub:

Type of change

  • New feature (non-breaking change which adds functionality)

How has this change been tested, please provide a testcase or example of how you tested the change?

Exporting:

from rfdetr import RFDETRBase

model = RFDETRBase()

model.export()
model.export(backbone_only=True)

Similarly for RFDETRLarge.

Testing:

Running with Transformers.js, the model outputs the correct response (i.e., valid model, and valid dynamic shape support). See PR at huggingface/transformers.js#1260.

Any specific deployment considerations

I would recommend further optimizing with the amazing onnxslim library; you can reduce model size quite a bit (eliminating tied weights and redundant ops).

Base:

+--------------------+-----------------------------------------+--------------------------------------+
|     Model Name     |          inference_model.onnx           |       ./final-base/model.onnx        |
+--------------------+-----------------------------------------+--------------------------------------+
|     Model Info     |       Op Set: 17 / IR Version: 8        |      Op Set: 17 / IR Version: 8      |
+--------------------+-----------------------------------------+--------------------------------------+
|     IN: input      |  float32: ('batch_size', 3, 'height',   | float32: ('batch_size', 3, 'height', |
|                    |                'width')                 |               'width')               |
|     OUT: dets      |      float32: ('Concatdets_dim_0',      |   float32: ('batch_size', 300, 4)    |
|                    | 'Concatdets_dim_1', 'Concatdets_dim_2') |                                      |
|    OUT: labels     |      float32: ('Addlabels_dim_0',       |   float32: ('batch_size', 300, 91)   |
|                    |         'Addlabels_dim_1', 91)          |                                      |
+--------------------+-----------------------------------------+--------------------------------------+
|        Add         |                   186                   |                 181                  |
|        And         |                    1                    |                  1                   |
|        Cast        |                   109                   |                  5                   |
|       Concat       |                   155                   |                  88                  |
|      Constant      |                  1068                   |                  0                   |
|  ConstantOfShape   |                   13                    |                  1                   |
|        Conv        |                    9                    |                  9                   |
|        Cos         |                    4                    |                  4                   |
|        Div         |                   77                    |                  47                  |
|       Equal        |                    8                    |                  2                   |
|        Erf         |                   11                    |                  11                  |
|        Exp         |                    3                    |                  3                   |
|       Expand       |                   16                    |                  6                   |
|       Gather       |                   178                   |                 122                  |
|   GatherElements   |                    1                    |                  1                   |
|      Greater       |                    2                    |                  2                   |
|     GridSample     |                    3                    |                  3                   |
| LayerNormalization |                   37                    |                  37                  |
|        Less        |                    1                    |                  1                   |
|       MatMul       |                   135                   |                 135                  |
|        Mul         |                   144                   |                 120                  |
|        Not         |                    3                    |                  3                   |
|        Pow         |                   10                    |                  9                   |
|       Range        |                    2                    |                  2                   |
|     ReduceMax      |                    1                    |                  1                   |
|     ReduceMean     |                   18                    |                  18                  |
|     ReduceSum      |                    4                    |                  4                   |
|        Relu        |                    8                    |                  8                   |
|      Reshape       |                   128                   |                 115                  |
|       Resize       |                    1                    |                  1                   |
|     ScatterND      |                    2                    |                  2                   |
|       Shape        |                   203                   |                  67                  |
|      Sigmoid       |                    8                    |                  8                   |
|        Sin         |                    4                    |                  4                   |
|       Slice        |                   60                    |                  37                  |
|      Softmax       |                   17                    |                  17                  |
|       Split        |                   15                    |                  2                   |
|        Sqrt        |                   42                    |                  9                   |
|      Squeeze       |                   21                    |                  3                   |
|        Sub         |                   19                    |                  14                  |
|        Tile        |                    4                    |                  4                   |
|        TopK        |                    1                    |                  1                   |
|     Transpose      |                   87                    |                  86                  |
|     Unsqueeze      |                   287                   |                 140                  |
|       Where        |                   10                    |                  4                   |
+--------------------+-----------------------------------------+--------------------------------------+
|     Model Size     |                114.70 MB                |              102.78 MB               |
+--------------------+-----------------------------------------+--------------------------------------+
|    Elapsed Time    |                                     35.79 s                                    |
+--------------------+-----------------------------------------+--------------------------------------+

Large:

+--------------------+-----------------------------------------+--------------------------------------+
|     Model Name     |          inference_model.onnx           |       ./final-large/model.onnx       |
+--------------------+-----------------------------------------+--------------------------------------+
|     Model Info     |       Op Set: 17 / IR Version: 8        |      Op Set: 17 / IR Version: 8      |
+--------------------+-----------------------------------------+--------------------------------------+
|     IN: input      |  float32: ('batch_size', 3, 'height',   | float32: ('batch_size', 3, 'height', |
|                    |                'width')                 |               'width')               |
|     OUT: dets      |      float32: ('Concatdets_dim_0',      |   float32: ('batch_size', 300, 4)    |
|                    | 'Concatdets_dim_1', 'Concatdets_dim_2') |                                      |
|    OUT: labels     |      float32: ('Addlabels_dim_0',       |   float32: ('batch_size', 300, 91)   |
|                    |         'Addlabels_dim_1', 91)          |                                      |
+--------------------+-----------------------------------------+--------------------------------------+
|        Add         |                   218                   |                 209                  |
|        And         |                    1                    |                  1                   |
|        Cast        |                   118                   |                  9                   |
|       Concat       |                   176                   |                 102                  |
|      Constant      |                  1229                   |                  0                   |
|  ConstantOfShape   |                   21                    |                  2                   |
|        Conv        |                   21                    |                  21                  |
|   ConvTranspose    |                    4                    |                  4                   |
|        Cos         |                    4                    |                  4                   |
|        Div         |                   93                    |                  63                  |
|       Equal        |                   15                    |                  2                   |
|        Erf         |                   11                    |                  11                  |
|        Exp         |                    3                    |                  3                   |
|       Expand       |                   27                    |                  11                  |
|       Gather       |                   187                   |                 127                  |
|   GatherElements   |                    1                    |                  1                   |
|      Greater       |                    2                    |                  2                   |
|     GridSample     |                    6                    |                  6                   |
| LayerNormalization |                   37                    |                  37                  |
|        Less        |                    1                    |                  1                   |
|       MatMul       |                   135                   |                 135                  |
|        Mul         |                   179                   |                 146                  |
|        Not         |                    3                    |                  3                   |
|        Pow         |                   23                    |                  22                  |
|       Range        |                    4                    |                  4                   |
|     ReduceMax      |                    1                    |                  1                   |
|     ReduceMean     |                   44                    |                  44                  |
|     ReduceSum      |                    4                    |                  4                   |
|        Relu        |                   12                    |                  12                  |
|      Reshape       |                   144                   |                 125                  |
|       Resize       |                    1                    |                  1                   |
|     ScatterND      |                    4                    |                  4                   |
|       Shape        |                   219                   |                  71                  |
|      Sigmoid       |                   16                    |                  16                  |
|        Sin         |                    4                    |                  4                   |
|       Slice        |                   72                    |                  41                  |
|      Softmax       |                   17                    |                  17                  |
|       Split        |                   23                    |                  5                   |
|        Sqrt        |                   55                    |                  22                  |
|      Squeeze       |                   42                    |                  6                   |
|        Sub         |                   39                    |                  29                  |
|        Tile        |                    4                    |                  4                   |
|        TopK        |                    1                    |                  1                   |
|     Transpose      |                   91                    |                  90                  |
|     Unsqueeze      |                   315                   |                 149                  |
|       Where        |                   17                    |                  4                   |
+--------------------+-----------------------------------------+--------------------------------------+
|     Model Size     |                477.79 MB                |              463.69 MB               |
+--------------------+-----------------------------------------+--------------------------------------+
|    Elapsed Time    |                                     86.97 s                                    |
+--------------------+-----------------------------------------+--------------------------------------+

Docs

  • Docs updated? What were the changes:

@CLAassistant
Copy link

CLAassistant commented Mar 29, 2025

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

@xenova xenova changed the title [ONNX] Add dynamic shapes support [ONNX] Add dynamic shapes support (+ in-browser inference w/ Transformers.js) Mar 29, 2025
@xenova xenova changed the title [ONNX] Add dynamic shapes support (+ in-browser inference w/ Transformers.js) [ONNX] Add dynamic shapes support (& in-browser inference w/ Transformers.js) Mar 29, 2025
@SkalskiP
Copy link
Collaborator

SkalskiP commented Mar 31, 2025

Hi @xenova 👋🏻 this looks very interesting! Could I ask you to accept the CLA? Without it, I won't be able to merge the PR.

I understand that now, after the changes to the export, it will be possible to perform batch inference with any number of images. These images need to have the same width and height, and the width and height must be divisible by 56?

Additionally, since the Roboflow app relies on the ONNX export output, I'll need to sync with them before merging this PR.

@xenova
Copy link
Author

xenova commented Mar 31, 2025

I understand that now, after the changes to the export, it will be possible to perform batch inference with any number of images. These images need to have the same width and height, and the width and height must be divisible by 56?

That is correct!

Additionally, since the Roboflow app relies on the ONNX export output, I'll need to sync with them before merging this PR.

Sounds good :) The default export code will produce models with the same signature as before. However, the models I've uploaded to the HF Hub just have some of the input and output node names updated to fit with the transformers/transformers.js standards. So, if you'd like them to be 100% backwards compatible, you can just re-export with the steps outlined in the README.

Could I ask you to accept the CLA? Without it, I won't be able to merge the PR.

Sure thing! I'm having some issues with the link given by the bot above, but according to CONTRIBUTING.md, it should be okay to add a comment stating:

I have read the CLA Document and I sign the CLA.

Let me know if that works!

@SkalskiP
Copy link
Collaborator

SkalskiP commented Apr 1, 2025

@xenova I believe that needs to be a separate comment. "I have read the CLA Document and I sign the CLA." But without the ". ;)

@xenova
Copy link
Author

xenova commented Apr 1, 2025

I have read the CLA Document and I sign the CLA.

@capjamesg
Copy link
Collaborator

@xenova our automated system is having trouble right now. We accept #79 (comment) as agreement so your PR can be merged.

@SkalskiP
Copy link
Collaborator

SkalskiP commented Apr 1, 2025

@Matvezy it's the ONNX PR we spoke about yesterday. any chance you could take a look and confirm there are no blockers for merging?

@capjamesg
Copy link
Collaborator

capjamesg commented Apr 1, 2025

@xenova Can you try to accept the CLA again at https://cla-assistant.io/roboflow/rf-detr?pullRequest=79? Apologies for the inconvenience.

@isaacrob-roboflow
Copy link
Collaborator

hi! pulling out the precomputed interpolated positional embeddings will make the onnx graph slower for any given image size, which is why we precomputed them in the first place. have you tested latency compared to the existing version?

@xenova
Copy link
Author

xenova commented Apr 2, 2025

CLA should be good now! 🤗 Let me know if I need to sign again or anything 👍

image

hi! pulling out the precomputed interpolated positional embeddings will make the onnx graph slower for any given image size, which is why we precomputed them in the first place. have you tested latency compared to the existing version?

Latency tests for exactly 560x560 input sizes would be interesting to see if anyone has an environment to do benchmarking (unfortunately, latency tests for other dimensions wouldn't be possible). Another option could be to pre-compute 560x560 positional embeddings and using that as the default, but the problem then would be that any use of other input sizes would be slower.

@isaacrob-roboflow
Copy link
Collaborator

You don't need the exact hardware to do a latency test, although of course it won't be comparable to the official result without it. You can use other hardware to provide SOME evidence that there is or is not a slowdown. But without evidence that this does not cause a slowdown (and it is likely to), this will not be merged as-is.

One other option is to use a flag to enable it on export that is default false. But still I'd like to see numbers.

We could run the benchmark at some point ourselves but it will be a bit before we have the bandwidth.

@SkalskiP SkalskiP changed the base branch from main to develop April 3, 2025 08:53
@capjamesg
Copy link
Collaborator

@xenova Can you try to sign the CLA again? Apologies for all of the inconvenience. Of note, you will need JavaScript enabled to sign the CLA.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants