❓ [Question] Manually Annotate Quantization Parameters in FX Graph #3522

Open
patrick-botco opened this issue May 16, 2025 · 6 comments
Labels: question (Further information is requested)

patrick-botco commented May 16, 2025

❓ Question

Is there a way to manually annotate quantization parameters that will be respected throughout torch_tensorrt conversion (e.g. by manually adding Q/DQ nodes, or by specifying some tensor metadata) via Dynamo? Thank you!

patrick-botco (Author) commented

cc @narendasan @peri044 maybe? 🙏

narendasan (Collaborator) commented

This should be possible, as this is effectively what the TensorRT Model Optimizer toolkit does. @peri044 or @lanluo-nvidia could maybe give more specific guidance.

peri044 (Collaborator) commented May 19, 2025

We currently use the NVIDIA Model Optimizer toolkit, which inserts quantization nodes into the torch model via its quantize API:

  1. https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/9c54aa1c47871d0541801a20962996461d805162/modelopt/torch/quantization/model_quant.py#L126
  2. https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/9c54aa1c47871d0541801a20962996461d805162/modelopt/torch/quantization/tensor_quant.py#L229-L243 (definition of the custom ops which do the quantization). We have converters for these quantization custom ops (which call the Q & DQ APIs in TensorRT).

You can also manually insert a quantization custom op by implementing a lowering pass that adds these nodes to the torch.fx.GraphModule, and then implementing/registering a custom converter for it. You can attach custom metadata to such a node by updating node.meta["val"]:

  1. https://docs.pytorch.org/TensorRT/contributors/writing_dynamo_aten_lowering_passes.html (existing lowering passes)
  2. https://docs.pytorch.org/TensorRT/contributors/dynamo_converters.html

This can be done outside the Torch-TRT codebase, using the decorators described in the pages above to register your lowering pass / converter; a rough sketch of both pieces follows below.
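To make this concrete, here is a minimal sketch of what such a pass and converter could look like. This is not code from the Torch-TRT codebase: the custom op my_ops::quantize is hypothetical, the decorator names (_aten_lowering_pass, dynamo_tensorrt_converter) and import paths follow the two docs pages linked above, and exact signatures may vary between releases.

import numpy as np
import torch
from torch_tensorrt.dynamo.conversion import dynamo_tensorrt_converter
from torch_tensorrt.dynamo.lowering.passes import _aten_lowering_pass

# Assumes a custom quantize op has been registered elsewhere, e.g.:
#   torch.library.define("my_ops::quantize", "(Tensor x, float amax) -> Tensor")


@_aten_lowering_pass
def insert_q_dq(gm: torch.fx.GraphModule, settings) -> torch.fx.GraphModule:
    # Wrap the activation input of every convolution in the custom quantize op
    for node in list(gm.graph.nodes):
        if node.target == torch.ops.aten.convolution.default:
            inp = node.args[0]
            with gm.graph.inserting_before(node):
                q = gm.graph.call_function(
                    torch.ops.my_ops.quantize.default,  # hypothetical custom op
                    args=(inp, 1.0),                    # (input, calibrated amax)
                )
            # Copy the FakeTensor metadata so downstream passes still see
            # the correct shape/dtype at this node
            q.meta["val"] = inp.meta["val"]
            node.replace_input_with(inp, q)
    gm.graph.lint()
    gm.recompile()
    return gm


@dynamo_tensorrt_converter(torch.ops.my_ops.quantize.default)
def quantize_converter(ctx, target, args, kwargs, name):
    # Lower the custom op to TensorRT Q -> DQ layers (symmetric INT8 here)
    inp, amax = args
    scale = ctx.net.add_constant(
        (1,), np.array([amax / 127.0], dtype=np.float32)
    ).get_output(0)
    q = ctx.net.add_quantize(inp, scale)
    dq = ctx.net.add_dequantize(q.get_output(0), scale)
    return dq.get_output(0)

With both registered, compilation through torch_tensorrt.dynamo.compile should pick up the pass during lowering and hit the converter during conversion; treat this as a starting point rather than a drop-in implementation.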

Please let me know if you have any further questions.

patrick-botco (Author) commented

Hey @peri044, thanks for the response. I tried modelopt -> export on the simple model below. Am I using this wrong, or missing something obvious? I'm using non-strict export (strict runs into torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)), but hitting ValueError: Node type mismatch; expected <class 'tuple'>, but got <class 'torch.Size'>. Thanks!

import modelopt.torch.quantization as mtq
import torch
from modelopt.torch.quantization.utils import export_torch_mode


class JustAConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, inputs):
        return self.conv(inputs)


if __name__ == "__main__":
    model = JustAConv().to("cuda").eval()
    sample_input = torch.ones(1, 3, 224, 224).to("cuda")
    quant_cfg = mtq.INT8_DEFAULT_CFG
    # calibrate and insert Q/DQ nodes via Model Optimizer
    mtq.quantize(
        model,
        quant_cfg,
        forward_loop=lambda model: model(sample_input),
    )

    # export the quantized model (export_torch_mode switches modelopt's
    # quantizers to their torch.export-compatible path); this is where
    # the ValueError above is raised
    with torch.no_grad():
        with export_torch_mode():
            exported_program = torch.export.export(model, (sample_input,), strict=False)

lanluo-nvidia (Collaborator) commented

@patrick-botco I tried your example with our latest main; with strict=False it works as expected. I suspect the error is related to your specific versions. Could you please let me know which versions you are using?

patrick-botco (Author) commented

Hey @lanluo-nvidia, thanks for checking! Here are my PyTorch and modelopt versions:

nvidia-modelopt           0.29.0
nvidia-modelopt-core      0.29.0
torch                     2.5.1
