Skip to content

[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE #20762

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

ElizaWszola
Copy link
Contributor

@ElizaWszola ElizaWszola commented Jul 10, 2025

This PR introduces a couple performance improvements to non-blockwise fp8 CUTLASS MoE. The improvements are:

  • Bringing back pre-calculation of ab_strides and c_strides. It had been disabled due to PPLX codepath compatibility issues which are now resolved.
  • Faster kernels for shuffling hidden_states, input scales and outputs of the function that executes CUTLASS MoE layer.

Benchmarks

Execution times of functioncutlass_moe_fp8 times in microseconds (μs):

-------------------------------------------------------------------------------------------------
                                                | current grouped_gemm_moe | new grouped_gemm_moe
-------------------------------------------------------------------------------------------------
num_experts=8, topk=2, MKN=((1, 4096, 28672))   | 8480.9                   | 8143.1
num_experts=8, topk=2, MKN=((4, 4096, 28672))   | 17199.9                  | 16930.2
num_experts=8, topk=2, MKN=((8, 4096, 28672))   | 26038.6                  | 25675.8
num_experts=8, topk=2, MKN=((16, 4096, 28672))  | 26167.1                  | 25848.8
num_experts=8, topk=2, MKN=((32, 4096, 28672))  | 26402.4                  | 26026.9
num_experts=8, topk=2, MKN=((64, 4096, 28672))  | 26816.0                  | 26389.1
num_experts=8, topk=2, MKN=((128, 4096, 28672)) | 27674.2                  | 27175.4
num_experts=8, topk=2, MKN=((256, 4096, 28672)) | 30751.2                  | 29755.0
num_experts=8, topk=2, MKN=((512, 4096, 28672)) | 38974.5                  | 38203.2
num_experts=8, topk=2, MKN=((1, 14336, 4096))   | 5662.6                   | 5385.5
num_experts=8, topk=2, MKN=((4, 14336, 4096))   | 11083.1                  | 10783.2
num_experts=8, topk=2, MKN=((8, 14336, 4096))   | 12398.0                  | 13571.3
num_experts=8, topk=2, MKN=((16, 14336, 4096))  | 11243.3                  | 13677.9
num_experts=8, topk=2, MKN=((32, 14336, 4096))  | 14217.7                  | 13805.0
num_experts=8, topk=2, MKN=((64, 14336, 4096))  | 14595.0                  | 14073.9
num_experts=8, topk=2, MKN=((128, 14336, 4096)) | 15392.0                  | 14538.9
num_experts=8, topk=2, MKN=((256, 14336, 4096)) | 17839.9                  | 16317.0
num_experts=8, topk=2, MKN=((512, 14336, 4096)) | 25584.6                  | 21823.1
num_experts=64, topk=6, MKN=((1, 2048, 1408))   | 2365.9                   | 2013.4
num_experts=64, topk=6, MKN=((4, 2048, 1408))   | 3615.8                   | 3224.5
num_experts=64, topk=6, MKN=((8, 2048, 1408))   | 4391.0                   | 4359.8
num_experts=64, topk=6, MKN=((16, 2048, 1408))  | 5686.3                   | 5189.1
num_experts=64, topk=6, MKN=((32, 2048, 1408))  | 6670.1                   | 6260.9
num_experts=64, topk=6, MKN=((64, 2048, 1408))  | 6955.2                   | 6528.7
num_experts=64, topk=6, MKN=((128, 2048, 1408)) | 7538.7                   | 6893.7
num_experts=64, topk=6, MKN=((256, 2048, 1408)) | 8498.4                   | 7582.4
num_experts=64, topk=6, MKN=((512, 2048, 1408)) | 10094.3                  | 8767.7
num_experts=32, topk=8, MKN=((1, 1024, 1024))   | 1838.5                   | 1531.2
num_experts=32, topk=8, MKN=((4, 1024, 1024))   | 2479.6                   | 2121.3
num_experts=32, topk=8, MKN=((8, 1024, 1024))   | 2730.8                   | 2351.4
num_experts=32, topk=8, MKN=((16, 1024, 1024))  | 2758.8                   | 2388.1
num_experts=32, topk=8, MKN=((32, 1024, 1024))  | 2818.1                   | 2445.5
num_experts=32, topk=8, MKN=((64, 1024, 1024))  | 2942.6                   | 2534.9
num_experts=32, topk=8, MKN=((128, 1024, 1024)) | 3316.4                   | 2745.9
num_experts=32, topk=8, MKN=((256, 1024, 1024)) | 3991.2                   | 3238.2
num_experts=32, topk=8, MKN=((512, 1024, 1024)) | 5357.5                   | 4314.8
num_experts=40, topk=8, MKN=((1, 1024, 1536))   | 1988.2                   | 1787.0
num_experts=40, topk=8, MKN=((4, 1024, 1536))   | 2928.6                   | 2461.8
num_experts=40, topk=8, MKN=((8, 1024, 1536))   | 3151.5                   | 2930.3
num_experts=40, topk=8, MKN=((16, 1024, 1536))  | 3611.4                   | 3209.7
num_experts=40, topk=8, MKN=((32, 1024, 1536))  | 3687.8                   | 3310.0
num_experts=40, topk=8, MKN=((64, 1024, 1536))  | 3846.8                   | 3438.9
num_experts=40, topk=8, MKN=((128, 1024, 1536)) | 4267.1                   | 3689.6
num_experts=40, topk=8, MKN=((256, 1024, 1536)) | 4986.2                   | 4230.9
num_experts=40, topk=8, MKN=((512, 1024, 1536)) | 6593.6                   | 5527.7

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the performance Performance-related issues label Jul 10, 2025
Copy link

mergify bot commented Jul 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ElizaWszola.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 10, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @ElizaWszola, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of non-blockwise FP8 CUTLASS Mixture-of-Experts (MoE) operations. It achieves this by re-integrating and optimizing the pre-calculation of GEMM strides and by deploying more efficient CUDA kernels for data shuffling within the MoE layer, leading to overall speed improvements.

Highlights

  • Re-enabled Stride Pre-calculation: The ab_strides and c_strides parameters, crucial for efficient GEMM operations in CUTLASS MoE, are now pre-calculated and passed explicitly to the kernels. This re-enables a performance optimization that was previously disabled due to PPLX compatibility issues, which are now resolved.
  • Optimized Data Shuffling Kernels: Introduced and integrated faster CUDA kernels for shuffling input hidden states, their scales, and the final outputs of the MoE layer. A fallback 'slow' kernel is also provided for cases where column alignment prevents the use of the fastest kernel, ensuring robustness.
  • API and Architecture Updates: The run_cutlass_moe_fp8 function and the CutlassExpertsFp8 class have been updated to accept ab_strides and c_strides as explicit arguments. The responsibility for calculating these strides has been centralized within CompressedTensorsW8A8FP8MoEMethod.process_weights_after_loading, streamlining their management.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces performance improvements for non-blockwise fp8 CUTLASS MoE. The main changes include pre-calculating strides and using faster shuffle kernels. The changes are well-contained and affect benchmark, test, and kernel implementation files.

My review identified a potential high-severity bug in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py where a function call is missing required arguments, which could lead to a runtime error. I also found a medium-severity maintainability issue in csrc/moe/moe_permute_unpermute_op.cu related to code duplication. I've provided suggestions to fix both issues. Overall, the changes align with the stated performance goals.

Comment on lines 956 to 977
if self.fused_experts is None:
# If no modular kernel is provided, use cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
return cutlass_moe_fp8(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
ab_strides1=self.ab_strides1,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.c_strides2,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The call to cutlass_moe_fp8 in the if self.fused_experts is None: block is missing the new stride arguments (ab_strides1, ab_strides2, c_strides1, c_strides2). Since the function signatures for MoE kernels are being updated in this PR to require these strides, this will likely cause a TypeError if this code path is executed. Please pass the stride tensors to the cutlass_moe_fp8 call.

            return cutlass_moe_fp8(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                activation=activation,
                global_num_experts=global_num_experts,
                expert_map=None if self.disable_expert_map else expert_map,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                ab_strides1=self.ab_strides1,
                ab_strides2=self.ab_strides2,
                c_strides1=self.c_strides1,
                c_strides2=self.c_strides2,
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
            )

Comment on lines +200 to +217
if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
// use slow kernel if num_cols can't be aligned to 128 bits
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
} else {
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The MOE_DISPATCH macro is duplicated in both the if and else branches. This can be refactored to have a single MOE_DISPATCH call with the conditional logic inside the lambda to improve code readability and maintainability by reducing duplication.

MOE_DISPATCH(input_tensor.scalar_type(), [&] {
  if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
    // use slow kernel if num_cols can't be aligned to 128 bits
    shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
        dst2src_map.data_ptr<int32_t>(),
        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
        num_dest_rows, num_cols);
  } else {
    shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
        dst2src_map.data_ptr<int32_t>(),
        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
        num_dest_rows, num_cols);
  }
});

Signed-off-by: ElizaWszola <ewszola@redhat.com>
@mergify mergify bot removed the needs-rebase label Jul 10, 2025
Copy link

mergify bot commented Jul 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ElizaWszola.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 11, 2025
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
@mergify mergify bot removed the needs-rebase label Jul 11, 2025
@ElizaWszola ElizaWszola marked this pull request as ready for review July 11, 2025 05:44
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Copy link

mergify bot commented Jul 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ElizaWszola.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 13, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
needs-rebase performance Performance-related issues
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant