Skip to content

Fix overflow indexing in causal_conv1d kernel #20938

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 14, 2025

Conversation

tdoublep
Copy link
Member

@tdoublep tdoublep commented Jul 14, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR fixes an indexing overflow issue that can occur when running mamba-based models in V1. Before V1 integration, the value of conv_state_batch_coord could vary from 0 to max_num_seqs-1. However, with V1 integration it can now vary between 0 and num_gpu_blocks-1. Depending on the model and deployment, the number of blocks can be orders of magnitude higher and I'm running into overflow issues unless we explicitly cast to tl.int64.

Please note I already made the equivalent fix to the selective_state_update as part of #20016. I forgot to make the equivalent change to this kernel, my bad.

@tlrmchlsmth @cyang49 @thoangtrvn

Test Plan

It is actually not that easy to catch this with a test because you need a lot of blocks and need to let it run for a while until the scheduler starts picking off blocks with a high enough index to cause the overflow. I think it could be good to write some tests to try and catch these kernel overflow issues in general, but I would advocate for getting this fix in now.

Test Result

n/a

(Optional) Documentation Update

n/a

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @tdoublep, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical indexing overflow bug in the causal_conv1d kernel, which is part of the Mamba model's operations. The issue arose due to an insufficient integer type for batch coordinates in continuous batching scenarios, particularly after V1 integration expanded the possible range of these coordinates. My change ensures robust indexing by explicitly casting the coordinate to a 64-bit integer type, preventing crashes and ensuring correct model execution under high load.

Highlights

  • Bug Fix: Integer Overflow: I've addressed an integer overflow issue within the _causal_conv1d_fwd_kernel function. Specifically, the conv_state_batch_coord variable, which is used for indexing, was not explicitly cast to tl.int64. With V1 integration, the potential range of this coordinate increased significantly (from max_num_seqs-1 to num_gpu_blocks-1), leading to overflow in certain scenarios, particularly with large numbers of GPU blocks.
  • Type Casting for Robustness: The fix involves explicitly casting the loaded conv_state_batch_coord value to tl.int64 when IS_CONTINUOUS_BATCHING is enabled. This ensures that the index can correctly accommodate larger values, preventing potential out-of-bounds access or incorrect memory addressing.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a critical integer overflow issue in the causal_conv1d kernel by casting the conv_state_batch_coord to tl.int64. The fix is correct and well-explained.

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 14, 2025
@tlrmchlsmth
Copy link
Collaborator

Thanks for the fix!

@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) July 14, 2025 18:30
@thoangtrvn
Copy link
Contributor

Thanks so much, @tdoublep .

@thoangtrvn
Copy link
Contributor

@tdoublep : Could you please adds the fix to that indexing issue in the second kernel as well.

@tlrmchlsmth tlrmchlsmth merged commit 86f3ac2 into vllm-project:main Jul 14, 2025
77 checks passed
py-andy-c pushed a commit to py-andy-c/vllm that referenced this pull request Jul 14, 2025
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
py-andy-c pushed a commit to py-andy-c/vllm that referenced this pull request Jul 14, 2025
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Jul 15, 2025
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants