Skip to content

[BugFix] fix 3 issues: (1) using metadata for causal-conv1d, (2) indexing overflow in v1 vLLM, and (3) init_states in v0 #20838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 15, 2025

Conversation

thoangtrvn
Copy link
Contributor

@thoangtrvn thoangtrvn commented Jul 11, 2025

Purpose

This PR fixes

Test Plan

In the main, running this is ok

export MODEL_NAME=mistralai/Mamba-Codestral-7B-v0.1
lm_eval --model vllm     --model_args pretrained=${MODEL_NAME},tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95 --batch_size auto --trust_remote_code  --tasks gsm8k --limit=100 --device cuda:0 

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.47|±  |0.0502|
|     |       |strict-match    |     5|exact_match|↑  | 0.47|±  |0.0502|

In the main, running this would trigger below error

export MODEL_NAME=ibm-ai-platform/Bamba-9B-v2
lm_eval --model vllm     --model_args pretrained=${MODEL_NAME},tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95 --batch_size auto --trust_remote_code  --tasks gsm8k --limit=100 --device cuda:0 

[rank0]:   File "/net/storage149/mnt/md0/tmhoangt/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/net/storage149/mnt/md0/tmhoangt/vllm_tuan/vllm_tuan/vllm/vllm/model_executor/models/bamba.py", line 502, in forward
[rank0]:     hidden_states = self.model(input_ids, positions, mamba_cache_params,
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/net/storage149/mnt/md0/tmhoangt/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/net/storage149/mnt/md0/tmhoangt/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/net/storage149/mnt/md0/tmhoangt/vllm_tuan/vllm_tuan/vllm/vllm/model_executor/models/bamba.py", line 352, in forward
[rank0]:     hidden_states, residual = layer(
[rank0]:                               ^^^^^^
[rank0]:   File "/net/storage149/mnt/md0/tmhoangt/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/net/storage149/mnt/md0/tmhoangt/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/net/storage149/mnt/md0/tmhoangt/vllm_tuan/vllm_tuan/vllm/vllm/model_executor/models/bamba.py", line 126, in forward
[rank0]:     hidden_states = self.mamba(hidden_states, mamba_cache_params,
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/net/storage149/mnt/md0/tmhoangt/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/net/storage149/mnt/md0/tmhoangt/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/net/storage149/mnt/md0/tmhoangt/vllm_tuan/vllm_tuan/vllm/vllm/model_executor/custom_op.py", line 44, in forward
[rank0]:     return self._forward_method(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/net/storage149/mnt/md0/tmhoangt/vllm_tuan/vllm_tuan/vllm/vllm/model_executor/layers/mamba/mamba_mixer2.py", line 650, in forward_cuda
[rank0]:     initial_states = torch.where(
[rank0]:                      ^^^^^^^^^^^^
[rank0]: RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0

The reason is that in the codepath for v0 vLLM of mamba-mixer2, has_initial_states_p is not actually the prefill-only data; but it is the data of mixed prefill-decode batch. Therefore, we need to apply has_initial_states_p[:num_prefills,...]

With the PR, running this is ok now:

export MODEL_NAME=ibm-ai-platform/Bamba-9B-v2
lm_eval --model vllm     --model_args pretrained=${MODEL_NAME},tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95 --batch_size auto --trust_remote_code  --tasks gsm8k --limit=100 --device cuda:0 

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.37|±  |0.0485|
|     |       |strict-match    |     5|exact_match|↑  | 0.36|±  |0.0482|

Test Result

(Optional) Documentation Update

vLLM

Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @thoangtrvn, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses two critical bug fixes related to the Mamba model implementation within vLLM. It ensures that the causal convolution operation correctly leverages metadata for performance and resolves an issue with the initialization of states for v0 vLLM models, preventing runtime errors during model evaluation with certain architectures like Bamba.

Highlights

  • Causal Convolution Metadata Usage: The causal_conv1d kernel in mamba_mixer2.py has been updated to correctly utilize mamba2_metadata for accelerated computation. Previously, this metadata was not being passed or used effectively, impacting performance.
  • V0 vLLM Initial States Correction: Fixed a bug in the construction of initial_states for v0 vLLM models within mamba_mixer2.py. The has_initial_states_p tensor was incorrectly assumed to be prefill-only, leading to a RuntimeError due to tensor size mismatch. The logic now correctly slices has_initial_states_p to [:num_prefills] when VLLM_USE_V1 is false, ensuring proper state initialization for mixed prefill-decode batches.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This PR fixes the causal-conv1d kernel by utilizing metadata for accelerated computation and resolves a tensor shape mismatch in constructing initial_states for v0 vLLM, which was causing a runtime error. The changes effectively address the described issues.

@thoangtrvn
Copy link
Contributor Author

test_hybrid.py passed

image

Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 14, 2025
@thoangtrvn
Copy link
Contributor Author

@cyang49 pointed out that the zeros_like leads to 84us overhead in E2E running, so we can replace with a cheaper one.

@thoangtrvn
Copy link
Contributor Author

Added indexing fix to the second kernel, similar to what proposed in the recent PR #20938, when the kernel is used in v1 VLLM which can have a much larger cache size and indexing can be overflow in 32-bit.

@thoangtrvn thoangtrvn changed the title [BugFix] fix two issues: using metadata for causal-conv1d and init_states in v0 [BugFix] fix 3 issues: (1) using metadata for causal-conv1d, (2) indexing overflow in v1 vLLM, and (3) init_states in v0 Jul 15, 2025
@tlrmchlsmth tlrmchlsmth merged commit f29fd8a into vllm-project:main Jul 15, 2025
68 checks passed
hj-mistral pushed a commit to hj-mistral/vllm that referenced this pull request Jul 19, 2025
…xing overflow in v1 vLLM, and (3) init_states in v0 (vllm-project#20838)

Signed-off-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Co-authored-by: Tuan M. Hoang-Trong <tmhoangt@us.ibm.com>
Signed-off-by: Himanshu Jaju <hj@mistral.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants