YOLOv12 Compilation from PyTorch #22022

@jishminor

Description

Hey Folks,

I'm new to IREE and MLIR-based compilation workflows for ML. I've been trying to get my hands dirty by compiling a few models with iree-turbine AOT; in the past I successfully lowered these same models to the ExecuTorch Edge dialect, so I wanted to compare the two processes.

I'm hitting some snags with AOT.

Here is my sample program for exporting:

import argparse
from typing import Optional, Tuple

import numpy as np
import torch
import iree.turbine.aot as aot
import iree.runtime as ireert
from ultralytics import YOLO

from torch.export import export, ExportedProgram

def main(
    model_name: str,
    input_dims: Tuple[int, int],
    output: str,
    device: str,
    val_dataset_yaml_path: Optional[str],
):
    # Load the model
    model = YOLO(model_name)

    # Ensure no in-place activations (e.g., SiLU(inplace=True)) are used, as they
    # can lead to in-place ops like aten.silu_ in the exported graph which the
    # IREE importer rejects. We toggle them off on all present models.
    def _disable_inplace_activations(module: torch.nn.Module):
        for m in module.modules():
            if isinstance(m, torch.nn.SiLU):
                m.inplace = False

    _disable_inplace_activations(model.model)

    # Initialize preprocessing/predictor by running a dummy prediction
    H, W = int(input_dims[0]), int(input_dims[1])
    np_dummy_tensor = np.ones((H, W, 3), dtype=np.float32)
    model.predict(np_dummy_tensor, imgsz=(H, W), device="cpu")

    pt_model = model.model.to(torch.device("cpu"))

    def transform_fn(frame):
        return model.predictor.preprocess([frame])

    sample_input = transform_fn(np_dummy_tensor)

    # Export and compile with iree-turbine
    with torch.inference_mode():
        torch_exported_program: ExportedProgram = export(
            pt_model, (sample_input,), strict=False
        )

    print("Exported program successfully.")
    print("Exported Program Details:")
    print("Graph Module:", torch_exported_program.graph_module.print_readable())
    print(torch_exported_program.module_call_graph)
    print("\n\n")

    # Now export to IREE from the ExportedProgram
    iree_export_output = aot.export(torch_exported_program)

    iree_export_output.save_mlir("yolo12.mlir")

    iree_export_output.compile(save_to=output)
    print(f"Compiled IREE module saved to: {output}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Export Ultralytics YOLO12 models to IREE with iree-turbine."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="yolo12s",
        choices=["yolo12n", "yolo12s", "yolo12m", "yolo12l", "yolo12x"],
        help="Ultralytics yolo12 model name.",
    )
    parser.add_argument(
        "--input_dims",
        type=eval,
        default=[640, 640],
        help="Input model dimensions [height, width]. Default [640, 640]",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="yolo12.vmfb",
        help="Output IREE VM flatbuffer path",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="local-task",
        help="IREE runtime device (e.g., local-task)",
    )

    args = parser.parse_args()

    # Run the main function with parsed arguments
    main(
        model_name=args.model_name,
        input_dims=args.input_dims,
        output=args.output,
        device=args.device,
        val_dataset_yaml_path=False,
    )
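
For context, the --device flag and the (currently unused) iree.runtime import are there because, once the .vmfb compiles, I plan to run it roughly like this. This is just a sketch: the "main" entry-point name and the NCHW input shape are my assumptions, not something I've verified against the compiled module.

import numpy as np
import iree.runtime as ireert

# Load the compiled flatbuffer on the local-task driver and invoke the
# exported entry point ("main" is assumed to be the default from aot.export).
vm_module = ireert.load_vm_flatbuffer_file("yolo12.vmfb", driver="local-task")
dummy = np.ones((1, 3, 640, 640), dtype=np.float32)
result = vm_module.main(dummy)
print(result)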

Exporting with torch.export works fine; passing the ExportedProgram to the IREE compiler is where I get an error:

error: failed to legalize operation 'torch.constant.int'
Traceback (most recent call last):
  File "iree_playground/examples/torch/yolov12_export.py", line 209, in <module>
    main(
    ~~~~^
        model_name=args.model_name,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<3 lines>...
        val_dataset_yaml_path=False,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "code/iree_playground/examples/torch/yolov12_export.py", line 76, in main
    iree_export_output.compile(save_to=output)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
  File ".pyenv/versions/iree/lib/python3.13/site-packages/iree/turbine/aot/exporter.py", line 160, in compile
    raise RuntimeError("Compilation failed: See diagnostics")
RuntimeError: Compilation failed: See diagnostics

Run with the Python environment activated:

python yolov12_export.py --model_name yolo12s --input_dims [640,640] --output yolo12.vmfb

To reproduce: I'm on Python 3.13 on a Mac M4, with torch 2.8.0 and ultralytics 8.3.97.
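
In case it helps with triage, the saved yolo12.mlir can also be compiled directly through the Python compiler API to surface the full MLIR diagnostics rather than just "Compilation failed: See diagnostics". A sketch, assuming the iree-compiler Python package is installed alongside iree-turbine; --mlir-print-ir-after-failure is the generic MLIR pass-manager flag, nothing IREE-specific:

import iree.compiler as ireec

# Compile the MLIR saved by save_mlir() above; extra_args forwards generic
# MLIR flags so the failing pass/op is printed instead of only the summary error.
ireec.tools.compile_file(
    "yolo12.mlir",
    target_backends=["llvm-cpu"],
    output_file="yolo12.vmfb",
    extra_args=["--mlir-print-ir-after-failure"],
)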
