[Model] Replace Mamba2 RMSNorm Gated with Fused Triton Kernel #20839

Open · wants to merge 6 commits into main from pr_fused_rmsnormgated

Conversation

@cyang49 (Contributor) commented Jul 11, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR improves RMSNormGated performance. The original implementation relies on unfused operations, which results in slow execution. In experiments with 16k and 32k prompt lengths, we found that the RMSNormGated layer took roughly as much time as the Mamba2 SSD computation, far longer than expected.
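
To make the bottleneck concrete, below is a minimal sketch of the unfused computation (SiLU gating, RMS statistic, normalize, scale). The function name and upcasting details are illustrative only and are not the actual vLLM code; the point is that each step is a separate GPU kernel that round-trips the activation through HBM.

```python
import torch
import torch.nn.functional as F


def rms_norm_gated_reference(x: torch.Tensor,
                             gate: torch.Tensor,
                             weight: torch.Tensor,
                             eps: float = 1e-5) -> torch.Tensor:
    """Unfused gated RMSNorm sketch: each line launches separate kernels."""
    dtype = x.dtype
    x = x * F.silu(gate)                                     # gating (elementwise)
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)   # RMS statistic (reduction)
    x = x.float() * torch.rsqrt(variance + eps)              # normalize (elementwise)
    return (weight * x).to(dtype)                            # scale and cast back
```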

The fix replaces the current implementation with a fused kernel adapted from the mamba_ssm repo. For now it applies only to the TP=1 and ngroups=1 case; similar fixes are likely needed for broader configurations.
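
For intuition, a heavily simplified sketch of a fused Triton kernel in the spirit of mamba_ssm's layernorm_gated.py is shown below: gating, the RMS reduction, normalization, and scaling all happen in registers within a single launch. This sketch assumes contiguous (M, N) inputs and omits the stride handling, group support, and autotuning of the actual kernel added by this PR.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _rms_norm_gated_kernel(X, Z, W, Y, eps,
                           N: tl.constexpr, BLOCK_N: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(X + row * N + cols, mask=mask, other=0.0).to(tl.float32)
    z = tl.load(Z + row * N + cols, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
    x = x * (z * tl.sigmoid(z))              # SiLU gating, fused in registers
    var = tl.sum(x * x, axis=0) / N          # RMS statistic
    y = x * (1.0 / tl.sqrt(var + eps)) * w   # normalize and scale
    tl.store(Y + row * N + cols, y.to(Y.dtype.element_ty), mask=mask)


def rms_norm_gated(x: torch.Tensor, z: torch.Tensor, weight: torch.Tensor,
                   eps: float = 1e-5) -> torch.Tensor:
    """Single-launch gated RMSNorm over contiguous (M, N) inputs (sketch)."""
    M, N = x.shape
    out = torch.empty_like(x)
    BLOCK_N = triton.next_power_of_2(N)
    _rms_norm_gated_kernel[(M,)](x, z, weight, out, eps, N=N, BLOCK_N=BLOCK_N)
    return out
```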

@tdoublep @tlrmchlsmth

Test Plan

Since RMSNormGated is replaced with a new implementation, output quality is tested end-to-end using lm_eval on gsm8k. The performance gain is shown through benchmark_latency results.
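
In addition to the e2e checks above, a quick local sanity check of numerical equivalence between a fused kernel and an unfused baseline could look like the following. This is not part of this PR's test plan; it reuses the illustrative rms_norm_gated and rms_norm_gated_reference sketches from the Purpose section above.

```python
import torch

torch.manual_seed(0)
x = torch.randn(64, 2048, device="cuda", dtype=torch.float16)
gate = torch.randn_like(x)
weight = torch.randn(2048, device="cuda", dtype=torch.float16)

ref = rms_norm_gated_reference(x, gate, weight)  # unfused baseline sketch
out = rms_norm_gated(x, gate, weight)            # fused Triton sketch

# fp16 inputs with fp32 accumulation: allow a loose tolerance
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
print("fused output matches unfused reference within tolerance")
```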

Test Result

Experiments were run on an H100-80GB.
"Before" results were measured at commit f29fd8a.

benchmark_latency.py

Before

16k

python benchmarks/benchmark_latency.py --model ibm-ai-platform/Bamba-9B-v2 --input-len=16384 --output-len=1 --batch-size=1 --max_num_batched_tokens=16384
Avg latency: 0.6322657451188812 seconds
10% percentile latency: 0.6287146508228034 seconds
25% percentile latency: 0.6306551870657131 seconds
50% percentile latency: 0.6316098154056817 seconds
75% percentile latency: 0.6337889281567186 seconds
90% percentile latency: 0.63668047292158 seconds
99% percentile latency: 0.6390077278809622 seconds

32k

python benchmarks/benchmark_latency.py --model ibm-ai-platform/Bamba-9B-v2 --input-len=32768 --output-len=1 --batch-size=1 --max_num_batched_tokens=32768
Avg latency: 1.2798980366593848 seconds
10% percentile latency: 1.2776596304494887 seconds
25% percentile latency: 1.2785978385945782 seconds
50% percentile latency: 1.2790819348301739 seconds
75% percentile latency: 1.2802647254429758 seconds
90% percentile latency: 1.2838081065099687 seconds
99% percentile latency: 1.286418360499665 seconds

After

16k

python benchmarks/benchmark_latency.py --model ibm-ai-platform/Bamba-9B-v2 --input-len=16384 --output-len=1 --batch-size=1 --max_num_batched_tokens=16384
Avg latency: 0.5812359099897245 seconds
10% percentile latency: 0.5767858538310975 seconds
25% percentile latency: 0.5795957769732922 seconds
50% percentile latency: 0.5808804880362004 seconds
75% percentile latency: 0.5830415590899065 seconds
90% percentile latency: 0.5850747511722147 seconds
99% percentile latency: 0.5883477316750213 seconds

32k

python benchmarks/benchmark_latency.py --model ibm-ai-platform/Bamba-9B-v2 --input-len=32768 --output-len=1 --batch-size=1 --max_num_batched_tokens=32768
Avg latency: 1.1791308337822557 seconds
10% percentile latency: 1.1760888285003603 seconds
25% percentile latency: 1.1774692864855751 seconds
50% percentile latency: 1.178756954614073 seconds
75% percentile latency: 1.180810405407101 seconds
90% percentile latency: 1.1822025504428892 seconds
99% percentile latency: 1.1837539302790538 seconds

Approximately 8-9% latency improvement observed.

lm_eval Bamba-9B-v2

lm_eval --model vllm  --model_args pretrained=ibm-ai-platform/Bamba-9B-v2,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95 --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k

Before

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4162|±  |0.0136|
|     |       |strict-match    |     5|exact_match|↑  |0.4132|±  |0.0136|

After

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4246|±  |0.0136|
|     |       |strict-match    |     5|exact_match|↑  |0.4246|±  |0.0136|

(Optional) Documentation Update


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot (Contributor) left a comment


Summary of Changes

Hello @cyang49, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've summarized this pull request as primarily focused on enhancing the performance of the RMSNormGated layer within the Mamba model executor. My analysis shows it achieves this by replacing the existing, less efficient unfused operations with a highly optimized, fused Triton kernel. I expect this change to yield substantial speedups, particularly for longer prompt lengths, by reducing the overhead of multiple discrete GPU operations into a single, efficient kernel.

Highlights

  • Performance Optimization: I see that the existing multi-operation RMSNormGated implementation has been replaced with a single, fused Triton kernel for significant speedup. This addresses performance bottlenecks identified with 16k and 32k prompt lengths.
  • Triton Kernel Integration: I've noted the introduction of a new Triton-based kernel (layernorm_gated.py), adapted from the mamba_ssm repository, to handle the fused RMSNorm and gating logic.
  • Targeted Application: I observe that the new fused kernel is currently applied specifically to TP=1 (Tensor Parallelism) and ngroups=1 configurations, with the PR description indicating that broader use cases might require similar fixes; a minimal dispatch sketch of this fused-vs-fallback selection follows below.
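
The selection logic described in the last highlight could be sketched as follows. Class and attribute names here are illustrative and do not match vLLM's actual gated-norm layer; the sketch simply guards the fused Triton path (sketched in the PR description above) behind the TP=1 / ngroups=1 condition and keeps the unfused path as a fallback.

```python
import torch


class GatedRMSNorm(torch.nn.Module):
    """Illustrative gated RMSNorm wrapper, not the vLLM class."""

    def __init__(self, hidden_size: int, eps: float = 1e-5,
                 n_groups: int = 1, tp_size: int = 1):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
        # The fused Triton path is only used in the supported configuration.
        self.use_fused = (tp_size == 1 and n_groups == 1)

    def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        if self.use_fused:
            return rms_norm_gated(x, gate, self.weight, eps=self.eps)
        # Fallback: unfused reference path for TP > 1 or ngroups > 1.
        return rms_norm_gated_reference(x, gate, self.weight, eps=self.eps)
```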

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request replaces the existing RMSNormGated implementation with a more performant, fused Triton kernel from the mamba_ssm library. The changes are well-contained and correctly implement the fusion for the TP=1 and ngroups=1 case as intended. The new Triton kernel and its wrapper functions are correctly implemented. I have one minor suggestion to improve code clarity by removing an unused parameter from a function signature in the new layernorm_gated.py file.

@cyang49 changed the title from "Replace RMSNorm Gated with fused triton kernel" to "[Model] Replace Mamba2 RMSNorm Gated with Fused Triton Kernel" on Jul 11, 2025
cyang49 added 2 commits July 15, 2025 16:15
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
@cyang49 force-pushed the pr_fused_rmsnormgated branch from d276a6f to 9af4bf9 on July 15, 2025 20:58
@cyang49 marked this pull request as ready for review on July 15, 2025 21:01
lint

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Yu Chin Fabian Lim <fabian.lim@gmail.com>
Co-authored-by: Yu Chin Fabian Lim <fabian.lim@gmail.com>
@cyang49 force-pushed the pr_fused_rmsnormgated branch from c0cff38 to ec0a8c6 on July 16, 2025 13:16
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
@tlrmchlsmth (Collaborator) left a comment


Looks good, just one comment

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
@tlrmchlsmth added the "ready" label (ONLY add when PR is ready to merge / full CI is needed) on Jul 22, 2025
@tlrmchlsmth enabled auto-merge (squash) on July 22, 2025 13:39
@cyang49 (Contributor, Author) commented Jul 22, 2025

Waiting for the unit test PR that covers RMSNormGated to be merged, instead of adding a new unit test here.

@tlrmchlsmth disabled auto-merge on July 22, 2025 14:34
@tlrmchlsmth enabled auto-merge (squash) on July 23, 2025 13:04