
How can I turn on the Warp specialization in the grouped_gemm op? #4321

@linlang1837

Description

Here is my code:

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
    grouped_gemm,
)

# Example problem sizes: G groups, each an (M, K) x (K, N) GEMM
G, M, N, K = 4, 128, 256, 512
num_warmup = 10
device = "cuda"

a = torch.randn(M * G, K, dtype=torch.bfloat16, device=device)
b = torch.randn(N * G, K, dtype=torch.bfloat16, device=device)
m_sizes = torch.full((G,), M, device=device, dtype=torch.int32)

for _ in range(num_warmup):
    result = grouped_gemm(
        a,
        b,
        m_sizes,
    )

When I run this code, the following warning is displayed:
fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py:975: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.

Is building Triton from source from that specific branch the only way to turn on warp specialization?
The versions of the relevant packages are:
fbgemm_gpu_genai 2025.6.10+cu126
pytorch-triton 3.3.1+gitc8757738
torch 2.8.0.dev20250611+cu126
triton 3.3.1
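
For reference, here is a rough sketch of how I tried to confirm the fallback. The keyword names num_consumer_groups and num_buffers_warp_spec are my assumption about what the ws-3.2.x fork adds to triton.Config (the stock triton 3.3.1 wheel does not seem to accept them), and the second check only works if grouped_gemm emits the UserWarning at call time rather than at import time. It reuses a, b, and m_sizes from the snippet above.

import inspect
import warnings

import triton

# Assumption: the ws-3.2.x fork of Triton adds warp-specialization kwargs
# (e.g. num_consumer_groups, num_buffers_warp_spec) to triton.Config;
# the stock triton 3.3.1 wheel does not have them.
ws_kwargs = {"num_consumer_groups", "num_buffers_warp_spec"}
config_params = set(inspect.signature(triton.Config.__init__).parameters)
print("triton", triton.__version__, "ws kwargs present:", ws_kwargs <= config_params)

# Capture the UserWarning that grouped_gemm emits when it falls back to the
# non-warp-specialized path (assumes the warning is raised during the call).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = grouped_gemm(a, b, m_sizes)
ws_disabled = any("Warp specialization is disabled" in str(w.message) for w in caught)
print("warp specialization disabled:", ws_disabled)

On my setup both checks indicate the warp-specialized path is unavailable, which matches the warning above.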
