
Enable V1 for Hybrid SSM/Attention Models #20016


Merged
32 commits merged into vllm-project:main on Jul 4, 2025

Conversation

tdoublep (Member) commented Jun 24, 2025

Purpose

This PR enables V1 for models that use both SSM and attention layers.

Related RFCs: #18571 #11382

cc @heheda12345 @tlrmchlsmth

Implementation

The current hybrid cache allocator implementation assumes that the page size is the same across all KV cache groups. Therefore, we need to ensure that the page size for the mamba groups is the same as the page size for the attention group. This can be achieved by two steps (sketched in code after the list):

  1. Padding the mamba page size so that it is a multiple of the attention page size for 16 tokens (i.e., 16 × the per-token attention page size). The factor of 16 is needed because the attention block size must generally be a multiple of 16.
  2. Asking the user to set the attention block size (e.g., block_size in the config) to a large enough value that the page sizes match. This is currently enforced by raising an error that tells the user the required value, but it could be automated.
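
As a rough sketch, the two steps could be expressed as follows. The function and variable names are illustrative and do not correspond to the actual vLLM code:

def cdiv(a: int, b: int) -> int:
    # Ceiling division.
    return -(a // -b)

def align_page_sizes(mamba_page_size: int, attn_page_size: int,
                     attn_block_size: int) -> tuple[int, int]:
    # Bytes needed to hold exactly 16 tokens of attention KV cache.
    attn_page_size_16 = 16 * attn_page_size // attn_block_size

    # Step 1: pad the mamba page size up to a multiple of that 16-token unit.
    padded_mamba_page_size = cdiv(mamba_page_size,
                                  attn_page_size_16) * attn_page_size_16

    # Step 2: the attention block size (in tokens) the user must configure so
    # that one attention page is exactly as large as one padded mamba page.
    required_attn_block_size = 16 * (padded_mamba_page_size //
                                     attn_page_size_16)
    return padded_mamba_page_size, required_attn_block_size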

Some other details:

  • The block size for the different KV cache groups is now different, but this doesn't seem to require huge changes. In fact, it is already mostly supported. Some minor changes were needed in _get_kv_cache_config_uniform_page_size.
  • Because the Mamba2AttentionMetadataBuilder uses the same reorder_batch logic as the FlashInferMetadataBuilder, it is required to use the FlashInfer backend for now.
  • As with the pure SSM models in V1 right now, prefix caching is not supported and enforce-eager must be used.
  • We also need to change the layout of the views into the shared KVCacheTensor for the mamba layers, to ensure that writing attention blocks does not corrupt mamba blocks and vice versa (see the figures below and the sketch after this list).
  • Fixed an int32 overflow in the mamba_ssm kernel.
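
As a hypothetical illustration of the KVCacheTensor view constraint mentioned above (the sizes and names below are invented, not taken from the vLLM implementation): both views keep the block index as the outermost dimension over the same flat buffer, so block i occupies the same byte range under either interpretation and writes to different block ids cannot overlap.

import torch

num_blocks, page_bytes = 4, 4096
kv_cache_tensor = torch.zeros(num_blocks * page_bytes, dtype=torch.uint8)

# Attention view: (num_blocks, block_size, bytes_per_token), 16 * 256 = 4096.
attn_view = kv_cache_tensor.view(num_blocks, 16, 256)

# Mamba view: (num_blocks, state_bytes), with the state padded so that every
# mamba block starts on a page boundary.
state_bytes = 4000
mamba_view = kv_cache_tensor.view(num_blocks, page_bytes)[:, :state_bytes]

# Writing attention block 2 leaves mamba block 1 untouched.
attn_view[2].fill_(1)
assert mamba_view[1].sum() == 0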

Models that need changes

  • Bamba
  • Zamba
  • Nemotron-H
  • GraniteMoeHybrid
  • Falcon H1

Alternative approaches considered:

The main alternative here would be to relax the constraint that the page size across the KV cache groups must be equal. This would be a more substantial change and introduces complexity in terms of managing memory fragmentation.

The main downside of the approach here is that we need to use such large block sizes for the attention layers. This could have performance implications. It might not matter that much for models like Granite 4.0, which only use a few attention layers, but could be more significant for models like Falcon H1, which uses the same number of attention and mamba layers. Benchmark results for both of these models look good though (see below).


Below are some figures illustrating the main changes of this PR:

Alignment of the page size:

Before:
[figure]

After:
[figure]

View into the KVCacheTensor for mamba layers:

Before (does not break anything because we don't actually share KV cache tensors between attention and mamba yet):

[figure]

After:

[figure]

Test Plan

I've added V1 tests for:

ibm-ai-platform/Bamba-9B-v1
Zyphra/Zamba2-1.2B-instruct
nvidia/Nemotron-H-8B-Base-8K
ibm-granite/granite-4.0-tiny-preview
tiiuae/Falcon-H1-0.5B-Base

which all pass locally on my H100 machine.

I tried to also add a test using hmellor/tiny-random-BambaForCausalLM, but that model uses a head size which isn't supported by FlashInfer (see above).
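
For reference, here is a rough sketch of how such a test could be parametrized. This is not the actual test code added in this PR; the model-to-block-size mapping is taken from the test excerpt quoted later in the review, and the prompt and assertions are placeholders.

import pytest

HYBRID_MODELS_V1_BLOCK_SIZES = {
    "ibm-ai-platform/Bamba-9B-v1": 528,
    "Zyphra/Zamba2-1.2B-instruct": 80,
    "nvidia/Nemotron-H-8B-Base-8K": 528,
    "ibm-granite/granite-4.0-tiny-preview": 400,
    "tiiuae/Falcon-H1-0.5B-Base": 800,
}

@pytest.mark.parametrize("model", sorted(HYBRID_MODELS_V1_BLOCK_SIZES))
def test_hybrid_model_v1(model: str, monkeypatch: pytest.MonkeyPatch) -> None:
    # V1 with the FlashInfer backend, as required by this PR.
    monkeypatch.setenv("VLLM_USE_V1", "1")
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")

    from vllm import LLM, SamplingParams  # import after setting the env vars

    llm = LLM(model=model,
              enforce_eager=True,
              enable_prefix_caching=False,
              block_size=HYBRID_MODELS_V1_BLOCK_SIZES[model])
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=8))
    assert outputs and outputs[0].outputs[0].text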

Test Result

Passing.

(Optional) Documentation Update

Benchmarking

Granite 4.0 Tiny

Client:

python benchmark_serving.py --model ibm-granite/granite-4.0-tiny-preview \
                            --dataset-name sharegpt \
                            --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
                            --ignore-eos

Server (V0):

vllm serve ibm-granite/granite-4.0-tiny-preview \
    --no-enable-prefix-caching \
    --disable-log-requests

result:

============ Serving Benchmark Result ============
Successful requests:                     983       
Benchmark duration (s):                  60.64     
Total input tokens:                      235329    
Total generated tokens:                  222483    
Request throughput (req/s):              16.21     
Output token throughput (tok/s):         3668.81   
Total Token throughput (tok/s):          7549.45   
---------------Time to First Token----------------
Mean TTFT (ms):                          22040.42  
Median TTFT (ms):                        20313.08  
P99 TTFT (ms):                           48121.13  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          62.48     
Median TPOT (ms):                        62.10     
P99 TPOT (ms):                           129.16    
---------------Inter-token Latency----------------
Mean ITL (ms):                           54.18     
Median ITL (ms):                         59.16     
P99 ITL (ms):                            139.23    
==================================================

Server (V1):

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER vllm serve ibm-granite/granite-4.0-tiny-preview \
    --enforce-eager \
    --block-size 400 \
    --no-enable-prefix-caching \
    --disable-log-requests

result:

============ Serving Benchmark Result ============
Successful requests:                     983       
Benchmark duration (s):                  54.97     
Total input tokens:                      235329    
Total generated tokens:                  222483    
Request throughput (req/s):              17.88     
Output token throughput (tok/s):         4047.24   
Total Token throughput (tok/s):          8328.17   
---------------Time to First Token----------------
Mean TTFT (ms):                          8630.11   
Median TTFT (ms):                        8742.72   
P99 TTFT (ms):                           12645.55  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          124.66    
Median TPOT (ms):                        79.82     
P99 TPOT (ms):                           308.56    
---------------Inter-token Latency----------------
Mean ITL (ms):                           67.41     
Median ITL (ms):                         48.29     
P99 ITL (ms):                            316.95    
==================================================

Falcon-H1-0.5B-Base

Client:

python benchmark_serving.py --model tiiuae/Falcon-H1-0.5B-Base  \
                            --dataset-name sharegpt \
                            --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
                            --ignore-eos

Server (V0):

vllm serve  tiiuae/Falcon-H1-0.5B-Base  \
    --no-enable-prefix-caching \
    --disable-log-requests

result:

============ Serving Benchmark Result ============
Successful requests:                     983       
Benchmark duration (s):                  46.21     
Total input tokens:                      234145    
Total generated tokens:                  227965    
Request throughput (req/s):              21.27     
Output token throughput (tok/s):         4932.78   
Total Token throughput (tok/s):          9999.28   
---------------Time to First Token----------------
Mean TTFT (ms):                          15951.45  
Median TTFT (ms):                        14284.85  
P99 TTFT (ms):                           38578.74  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          52.51     
Median TPOT (ms):                        51.35     
P99 TPOT (ms):                           142.91    
---------------Inter-token Latency----------------
Mean ITL (ms):                           43.82     
Median ITL (ms):                         33.28     
P99 ITL (ms):                            92.41     
==================================================

Server (V0) EAGER:

vllm serve  tiiuae/Falcon-H1-0.5B-Base  \
    --no-enable-prefix-caching \
    --disable-log-requests \
    --enforce-eager

result:

============ Serving Benchmark Result ============
Successful requests:                     983       
Benchmark duration (s):                  79.42     
Total input tokens:                      234145    
Total generated tokens:                  227965    
Request throughput (req/s):              12.38     
Output token throughput (tok/s):         2870.21   
Total Token throughput (tok/s):          5818.24   
---------------Time to First Token----------------
Mean TTFT (ms):                          19809.95  
Median TTFT (ms):                        17679.59  
P99 TTFT (ms):                           50738.95  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          75.34     
Median TPOT (ms):                        71.15     
P99 TPOT (ms):                           214.25    
---------------Inter-token Latency----------------
Mean ITL (ms):                           64.92     
Median ITL (ms):                         46.68     
P99 ITL (ms):                            105.26    
==================================================

Server (V1) EAGER:

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER vllm serve tiiuae/Falcon-H1-0.5B-Base  \
    --enforce-eager \
    --block-size 800 \
    --no-enable-prefix-caching \
    --disable-log-requests

result:

============ Serving Benchmark Result ============
Successful requests:                     983       
Benchmark duration (s):                  38.29     
Total input tokens:                      234145    
Total generated tokens:                  227965    
Request throughput (req/s):              25.67     
Output token throughput (tok/s):         5953.81   
Total Token throughput (tok/s):          12069.03  
---------------Time to First Token----------------
Mean TTFT (ms):                          6042.31   
Median TTFT (ms):                        6135.97   
P99 TTFT (ms):                           7618.73   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          59.17     
Median TPOT (ms):                        40.94     
P99 TPOT (ms):                           140.24    
---------------Inter-token Latency----------------
Mean ITL (ms):                           37.34     
Median ITL (ms):                         33.27     
P99 ITL (ms):                            129.70    
==================================================

Follow-ups:

  • More flexible reorder_batch
  • Support (2, num_block) layout
  • Resolve the extremely large block_size problem
  • Automate block_size selection
  • The limitations in [v1] Support mamba2 #19327

tdoublep added 4 commits June 23, 2025 19:14

gemini-code-assist (bot) left a comment:

Summary of Changes

Hello @tdoublep, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces V1 compatibility for hybrid SSM/Attention models, specifically addressing the complexities of KV cache management when different layer types have distinct caching requirements. I've implemented logic to align KV cache page sizes and adapted the system to handle varying block sizes across different cache groups, ensuring proper memory allocation and access for these complex architectures.

Highlights

  • V1 Support for Hybrid Models: I've enabled V1 support for models that incorporate both State Space Models (SSM) like Mamba and traditional Attention layers. This is a significant step towards broader model compatibility within the V1 framework.
  • KV Cache Page Size Alignment: A core challenge addressed was ensuring consistent page sizes across different KV cache groups (SSM and Attention). I've implemented a mechanism to pad the Mamba page size to be a multiple of 16 * attn_page_size, which is crucial for the hybrid cache allocator.
  • Dynamic Block Size Handling: The system now supports different block sizes for various KV cache groups. I've updated KVCacheCoordinator and KVCacheManager to accommodate this, making the block size assertion conditional on caching being enabled.
  • Attention Backend Requirement: Due to the reorder_batch logic used by Mamba2AttentionMetadataBuilder, it's currently required to use the FlashInfer backend for hybrid models in V1. This is explicitly set in the test environment.
  • User Configuration for Attention Block Size: Users are now required to set the attention block size (e.g., block_size in config) to a value large enough to ensure page sizes match between Mamba and Attention layers. A ValueError is raised if this condition is not met.

@mergify mergify bot added the v1 label Jun 24, 2025

gemini-code-assist (bot) left a comment:

Code Review

This PR enables V1 support for hybrid SSM/Attention models by ensuring page size compatibility between Mamba and Attention layers, using FlashInfer backend, and enforcing eager execution. It includes padding the mamba page size, requiring users to set attention block size, and adding a test case with ibm-ai-platform/Bamba-9B-v1.

# Mamba state must be padded to an integer number of
# 16th tokens worth of attention pages
attn_layer_name = next(iter(attn_layers))
attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes

Collaborator:

to clarify here, kv_cache_spec[attn_layer_name].page_size_bytes is for a single token stored in the mamba cache?

Member Author:

Not exactly. kv_cache_spec[attn_layer_name].page_size_bytes gives us the size in bytes of an attention page that stores block_size tokens. The goal here is to figure out the size in bytes of an attention page that stores exactly 16 tokens, hence on the next line we divide by block_size to normalize it and multiply by 16.

Why do we want to know the size in bytes of an attention page that stores 16 tokens? It's because we want to ensure that the mamba page size is padded up to a value that the user can align the attention page size with. Since the user can only set the attention block size in multiples of 16, that is why the factor of 16 is needed.
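
For illustration only, here is a worked example with made-up sizes (none of these numbers come from a real model or from the vLLM code):

block_size = 16                     # current attention block size (tokens)
attn_page_size = 32 * 1024          # bytes of one 16-token attention page
mamba_page_size = 350 * 1024        # bytes of one mamba state page

attn_page_size_16 = 16 * attn_page_size // block_size   # 32 KiB per 16 tokens
num_units = -(mamba_page_size // -attn_page_size_16)    # ceil(350 / 32) = 11
padded_mamba_page_size = num_units * attn_page_size_16  # 352 KiB
required_attn_block_size = 16 * num_units                # block_size >= 176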

Collaborator:

Can you add a comment to explain the magic number "16"?
And to confirm: an arbitrary block_size that is supported by the attention backend would be OK, but as we don't know which block_size each attention backend supports, we have to hardcode "16", which is supported by most attention backends.

Collaborator:

@tlrmchlsmth If I remember correctly, you told me that FlashMLA does not support block_size 16. Can you confirm? If it is true, we may need some other assertion here.

Member Author:

And to confirm: an arbitrary block_size that is supported by the attention backend would be OK, but as we don't know which block_size each attention backend supports, we have to hardcode "16", which is supported by most attention backends.

So actually, I think this magic "16" is not strictly needed. The constraint that the block size must be a multiple of 16 only comes from the FlashAttention backend (which is not compatible with Mamba right now, for the reasons discussed). I just checked, and with FlashInfer it is possible to set the block size to any number.

Still, it probably makes sense to keep "16" since we want to support FlashAttention in the near future. Do you agree @heheda12345?

Member Author:

Can you add a comment to explain the magic number "16"?

Done

Comment on lines 2552 to 2555
raise ValueError(
"Attention block size must be increased to "
f"{required_attn_block_size} in order to match "
"the mamba page size")

Collaborator:

I think this is a fairly reasonable approach, especially for a first pass

Member Author:

The main question is whether we are OK with vLLM V1 failing under default parameters for hybrid models. If not, we could automatically scale up the attention block size and log what is happening to inform the user, rather than explicitly asking the user to do it.

Collaborator:

I think that's a better option, paired with logging a warning. But that could also wait for a follow-up.
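
A minimal sketch of the auto-scaling alternative being discussed, assuming hypothetical helper and variable names (this is not the actual vLLM code):

import logging

logger = logging.getLogger(__name__)

def maybe_scale_block_size(block_size: int,
                           required_attn_block_size: int) -> int:
    # Instead of raising a ValueError, bump the attention block size to the
    # smallest value that matches the padded mamba page size and warn the user.
    if block_size >= required_attn_block_size:
        return block_size
    logger.warning(
        "Increasing attention block size from %d to %d to match the mamba "
        "page size.", block_size, required_attn_block_size)
    return required_attn_block_size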

Collaborator:

I think the printed block_size is not a "must be", but rather a value that we are suggesting.

Member Author:

It needs to be at least this value in order to work though, right? I can't really think of practical scenarios where we would want the attention page size to be bigger than the mamba page size. The mamba page size is typically orders of magnitude bigger than the attention page size (per token). If the attention page size were bigger, we would need to pad the mamba page up to it and waste more space.

Member Author:

I've changed the language in the exception, please take a look.

tdoublep added 4 commits June 26, 2025 09:20
tdoublep added 6 commits June 26, 2025 15:38

mergify bot commented Jun 30, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tdoublep.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 30, 2025
tdoublep added 3 commits July 1, 2025 20:52
tdoublep (Member Author) commented Jul 1, 2025

@heheda12345 I have worked through all of your suggestions. When you get a chance, please take another look.

heheda12345 (Collaborator) left a comment:

The implementation looks good to me. I only left some comments related to coding style.

Comment on lines +65 to +69
"ibm-ai-platform/Bamba-9B-v1": 528,
"Zyphra/Zamba2-1.2B-instruct": 80,
"nvidia/Nemotron-H-8B-Base-8K": 528,
"ibm-granite/granite-4.0-tiny-preview": 400,
"tiiuae/Falcon-H1-0.5B-Base": 800,

Collaborator:

Is it possible to call that kernel during the engine's warm-up stage instead of the warmup prompt? (I'm OK with leaving a warning during engine warmup in this PR when a strange block_size is used and fixing it later.)

tdoublep added 4 commits July 4, 2025 11:39
tdoublep (Member Author) commented Jul 4, 2025

@heheda12345 I've worked through your suggestions - ready for another look when you have time.

@@ -437,7 +444,7 @@ def load_weights(self, weights: Iterable[tuple[str,

 class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                           IsHybrid, SupportsV0Only, SupportsQuant):
+                           IsHybrid, SupportsQuant):

Member:

Can you update the Supported Models page and also the V1 Guide accordingly?

Member Author:

I would propose we do that after the follow-up PR to automate the attention block size configuration. Until we do that, it is not possible for the user to pass a non-standard attention block size via the vllm serve CLI, since it will only accept the values defined here. Offline inference does allow it, though.

Please note that vLLM will still not use V1 by default for these models until we remove this logic.

Member Author:

@heheda12345 what do you think on this?

Member:

@heheda12345 are you ok with this?

Collaborator:

I agree that we can update the doc after fixing the block size problem. @DarkLight1337 WDYT?

Member:

Sure, let's get this merged first then

Member Author:

Thanks, I will prioritize that follow-up to get V1 online mode working for these models.

heheda12345 (Collaborator) left a comment:

LGTM! Thanks for the great work, and looking forward to the follow-ups. Can you update the document as suggested by @DarkLight1337?

@heheda12345 heheda12345 added the ready label Jul 4, 2025
@heheda12345 heheda12345 enabled auto-merge (squash) July 4, 2025 13:38
auto-merge was automatically disabled July 4, 2025 15:10

Head branch was pushed to by a user without write access

@heheda12345 heheda12345 enabled auto-merge (squash) July 4, 2025 16:21
@heheda12345 heheda12345 merged commit 2f35a02 into vllm-project:main Jul 4, 2025
71 checks passed
sfeng33 pushed a commit to sfeng33/vllm that referenced this pull request Jul 6, 2025
huydhn pushed a commit to huydhn/vllm that referenced this pull request Jul 8, 2025
Chen-zexi pushed a commit to Chen-zexi/vllm that referenced this pull request Jul 13, 2025
patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Jul 15, 2025

Labels: mamba/ssm, ready, v1