From f9a5f4e4f01dbf8b4bf41e2a6ce2b48666f51355 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 4 Jul 2025 19:20:30 +0000 Subject: [PATCH 01/11] Automate choice of attention block size; update docs Signed-off-by: Thomas Parnell --- docs/models/supported_models.md | 10 +- docs/usage/v1_guide.md | 15 +- .../models/language/generation/test_hybrid.py | 16 +- vllm/v1/worker/gpu_model_runner.py | 171 ++++++++++++------ 4 files changed, 130 insertions(+), 82 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 23d71fd44525..ff5e9e98b4cf 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -319,7 +319,7 @@ Specified using `--task generate`. | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | | +| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -335,7 +335,7 @@ Specified using `--task generate`. | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ | | `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ | -| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | | +| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -348,7 +348,7 @@ Specified using `--task generate`. | `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | | `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | | +| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. 
| ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | @@ -367,7 +367,7 @@ Specified using `--task generate`. | `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | | +| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ | | `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | @@ -392,7 +392,7 @@ Specified using `--task generate`. | `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`etc. | | | | | `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | | -| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | | +| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | |✅︎ | !!! note Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 82a2710d895c..63a29cda3443 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -83,7 +83,8 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the | **Decoder-only Models** | 🚀 Optimized | | **Encoder-Decoder Models** | 🟠 Delayed | | **Embedding Models** | 🟢 Functional | -| **Mamba Models** | 🚧 WIP ([PR #19327](https://github.com/vllm-project/vllm/pull/19327)) | +| **Mamba Models** | 🟢 Functional | +| **Hybrid Models** | 🟢 Functional | | **Multimodal Models** | 🟢 Functional | vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol. @@ -104,8 +105,16 @@ to enable simultaneous generation and embedding using the same engine instance i #### Mamba Models -Models using selective state-space mechanisms instead of standard transformer attention (e.g., `MambaForCausalLM`, `JambaForCausalLM`) -will be supported via [PR #19327](https://github.com/vllm-project/vllm/pull/19327). +Models using selective state-space mechanisms instead of standard transformer attention are partially supported. +Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers +(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet suported. 
Please note that these models currently require +enforcing eager mode and disabling prefix caching in V1. + +#### Hybrid Models + +Models that combined Mamba-2 layers with standard transformer attention layers are supported (e.g., `BambaForCausalLM`, +`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that +these models currently require enforcing eager mode and disabling prefix caching in V1. #### Encoder-Decoder Models diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index ecaae3ec1fc4..eba14e64553e 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -61,14 +61,6 @@ "tiiuae/Falcon-H1-0.5B-Base", ] -ATTN_BLOCK_SIZES = { - "ibm-ai-platform/Bamba-9B-v1": 528, - "Zyphra/Zamba2-1.2B-instruct": 80, - "nvidia/Nemotron-H-8B-Base-8K": 528, - "ibm-granite/granite-4.0-tiny-preview": 400, - "tiiuae/Falcon-H1-0.5B-Base": 800, -} - # Avoid OOM MAX_NUM_SEQS = 4 @@ -105,11 +97,6 @@ def test_models( example_prompts, max_tokens, num_logprobs) if model in V1_SUPPORTED_MODELS: - if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES: - block_size = ATTN_BLOCK_SIZES[model] - else: - block_size = 16 - with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") if model in HYBRID_MODELS: @@ -118,8 +105,7 @@ def test_models( with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, enforce_eager=True, - enable_prefix_caching=False, - block_size=block_size) as vllm_model: + enable_prefix_caching=False) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) else: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 57d0c7b50ff5..6f84a19ae2a2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2613,10 +2613,105 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. """ - block_size = self.vllm_config.cache_config.block_size + attn_block_size = self.vllm_config.cache_config.block_size + + kv_cache_spec, attn_page_size, mamba_page_size = \ + self._get_kv_cache_spec_for_block_size( + attn_block_size=attn_block_size, + ) + + if (attn_page_size is not None and mamba_page_size is not None + and attn_page_size != mamba_page_size): + + kv_cache_spec = self._align_hybrid_page_size( + attn_page_size, + mamba_page_size, + attn_block_size, + ) + + return kv_cache_spec + + def _align_hybrid_page_size( + self, attn_page_size: int, mamba_page_size: int, + attn_block_size: int) -> dict[str, KVCacheSpec]: + """ + Ensures that the page size for attention layers and mamba layers + are equal. This is achieved by (1) ensuring that the attention + block size is set such that the page size is greater than or + equal to the mamba page size and then (2) padding the mamba page + size to ensure they are precisely equal. + + Args: + attn_page_size: Page size (bytes) for attention layers. + mamba_page_size: Page size (bytes) for mamba layers. + attn_block_size: Block size (tokens) for attention layers. + Returns: + KVCacheSpec: A dictionary mapping layer names to their KV cache + format. Layers that do not need KV cache are not included. 
+ """ + + # check if we need to increase attention block size + if attn_page_size < mamba_page_size: + + # attention page size (for 1 token) + attn_page_size_1 = attn_page_size // attn_block_size + + # some attention backends (e.g. FA) only support setting + # block size to multiple of 16, so let's suggest a value + # that would work (note: FA is currently not compatible + # with mamba layers, use FlashInfer instead). + attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1) + + logger.info( + "Setting attention block size to %d tokens " + "to ensure that attention page size is >= mamba page size.", + attn_block_size) + + mamba_page_size_padded = attn_block_size * attn_page_size_1 + + mamba_padding_pct = 100 * (mamba_page_size_padded - + mamba_page_size) / mamba_page_size + + logger.info( + "Padding mamba page size by %.2f%% to ensure " + "that mamba page size and attention page size are " + "exactly equal.", mamba_padding_pct) + + kv_cache_spec, new_attn_page_size, new_mamba_page_size = \ + self._get_kv_cache_spec_for_block_size( + attn_block_size=attn_block_size, + mamba_page_size_padded=mamba_page_size_padded, + ) + + assert new_attn_page_size == new_mamba_page_size + + return kv_cache_spec + + def _get_kv_cache_spec_for_block_size( + self, + attn_block_size: int, + mamba_page_size_padded: Optional[int] = None, + ) -> tuple[dict[str, KVCacheSpec], Optional[int], Optional[int]]: + """ + Generates the KVCacheSpec by parsing the kv cache format from each + Attention module in the static forward context, assuming a given + block size (in tokens) for attention layers. + + Args: + attn_block_size: Block size (in tokens) for attention layers. + mamba_page_size_padded: Padded page size for mamba layers. + Returns: + KVCacheSpec: A dictionary mapping layer names to their KV cache + format. Layers that do not need KV cache are not included. + Optional[int]: Page size for attention layers. + Optional[int]: Page size for mamba layers. + """ + use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + + attn_page_size = None for layer_name, attn_module in attn_layers.items(): if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: @@ -2634,7 +2729,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( - block_size=block_size, + block_size=attn_block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, @@ -2642,11 +2737,17 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: use_mla=use_mla) else: kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, + block_size=attn_block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, use_mla=use_mla) + page_size = kv_cache_spec[layer_name].page_size_bytes + if attn_page_size is None: + attn_page_size = page_size + else: + assert page_size == attn_page_size + elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. 
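The arithmetic performed by `_align_hybrid_page_size` can be checked in isolation. A minimal sketch, assuming Bamba-9B-like dimensions (8 KV heads of size 128, mamba intermediate size 8192, 1 group, `d_state=128`, `d_conv=4`, 128 mamba heads of head dim 64, bf16, TP=1; these dimensions are assumptions, not taken from this patch), reproduces the 528-token block size that the test table `ATTN_BLOCK_SIZES` lists for Bamba:

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, matching vllm.utils.cdiv."""
    return -(a // -b)

BYTES = 2  # bf16 element size

# Attention page size for a single token: K and V for every KV head.
num_kv_heads, head_size = 8, 128
attn_page_size_1 = 2 * num_kv_heads * head_size * BYTES            # 4096 bytes

# Mamba page size: conv state (conv_dim, d_conv - 1) plus SSM state
# (n_heads, d_head, d_state); one block holds the whole sequence state.
conv_dim = 8192 + 2 * 1 * 128                    # intermediate + 2 * n_groups * d_state
mamba_page_size = (conv_dim * (4 - 1) + 128 * 64 * 128) * BYTES     # 2,147,840 bytes

# (1) smallest multiple-of-16 attention block whose page covers the mamba page
attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1)
assert attn_block_size == 528

# (2) pad the mamba page up so the two page sizes match exactly (~0.7% padding here)
mamba_page_size_padded = attn_block_size * attn_page_size_1         # 2,162,688 bytes
```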
@@ -2659,6 +2760,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaMixer2) + mamba_page_size = None if len(mamba_layers) > 0: if self.vllm_config.speculative_config is not None: raise NotImplementedError( @@ -2671,10 +2773,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: "Prefix caching is not supported for Mamba yet.") max_model_len = self.vllm_config.model_config.max_model_len - page_size_padded = self._maybe_pad_mamba_page_size( - attn_layers, mamba_layers, kv_cache_spec, max_model_len, - block_size) - # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. for layer_name, mamba_module in mamba_layers.items(): @@ -2682,57 +2780,12 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: shapes=mamba_module.get_state_shape(), dtype=self.kv_cache_dtype, block_size=max_model_len, - page_size_padded=page_size_padded) - - return kv_cache_spec - - def _maybe_pad_mamba_page_size( - self, - attn_layers: dict[str, Attention], - mamba_layers: dict[str, MambaMixer2], - kv_cache_spec: dict[str, KVCacheSpec], - max_model_len: int, - block_size: int, - ) -> Optional[int]: - """ - Ensure that page size of attention KV cache groups is greater than or - equal to the mamba KV cache groups. If not, we suggest to the user - how to set the attention block size to ensure that it is. - - If the attention page size is strictly greater than the mamba page size, - we pad the mamba page size to make them equal. - - Args: - attn_layers: Attention layers - mamba_layers: Mamba layers - kv_cache_spec: KV cache spec (populated with attention layers) - - Returns: - Optional[int]: Mamba page size with padding (None if no padding). - """ + page_size_padded=mamba_page_size_padded) - if len(attn_layers) == 0: - return None + page_size = kv_cache_spec[layer_name].page_size_bytes + if mamba_page_size is None: + mamba_page_size = page_size + else: + assert page_size == mamba_page_size - attn_layer_name = next(iter(attn_layers)) - attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes - mamba_layer_name = next(iter(mamba_layers)) - mamba_page_size = MambaSpec( - shapes=mamba_layers[mamba_layer_name].get_state_shape(), - dtype=self.kv_cache_dtype, - block_size=max_model_len).page_size_bytes - if attn_page_size < mamba_page_size: - # attention page size (for 16 tokens) - attn_page_size_16 = 16 * attn_page_size // block_size - # some attention backends (e.g. FA) only support setting - # block size to multiple of 16, so let's suggest a value - # that would work (note: FA is currently not compatible - # with mamba layers, use FlashInfer instead). - suggest_attn_block_size = 16 * cdiv(mamba_page_size, - attn_page_size_16) - raise ValueError( - "Attention block size should be increased to at least " - f"{suggest_attn_block_size} in order to match " - "the mamba page size") - - return attn_page_size + return kv_cache_spec, attn_page_size, mamba_page_size From e618204aef0370dc324a873665460d16fedff1e9 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 4 Jul 2025 20:13:44 +0000 Subject: [PATCH 02/11] Fix typo Signed-off-by: Thomas Parnell --- docs/usage/v1_guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 63a29cda3443..99994c9dff63 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -112,7 +112,7 @@ enforcing eager mode and disabling prefix caching in V1. 
#### Hybrid Models -Models that combined Mamba-2 layers with standard transformer attention layers are supported (e.g., `BambaForCausalLM`, +Models that combine Mamba-2 layers with standard transformer attention layers are supported (e.g., `BambaForCausalLM`, `Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that these models currently require enforcing eager mode and disabling prefix caching in V1. From 600ec11cc6089b94096d51e06e4e279a207efec5 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 4 Jul 2025 20:19:09 +0000 Subject: [PATCH 03/11] More verbose logging / asserts Signed-off-by: Thomas Parnell --- vllm/v1/worker/gpu_model_runner.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6f84a19ae2a2..2ebaca6ee8c6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2663,9 +2663,9 @@ def _align_hybrid_page_size( attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1) logger.info( - "Setting attention block size to %d tokens " + "Increasing attention block size from %d to %d tokens " "to ensure that attention page size is >= mamba page size.", - attn_block_size) + self.vllm_config.cache_config.block_size, attn_block_size) mamba_page_size_padded = attn_block_size * attn_page_size_1 @@ -2683,7 +2683,9 @@ def _align_hybrid_page_size( mamba_page_size_padded=mamba_page_size_padded, ) - assert new_attn_page_size == new_mamba_page_size + assert new_attn_page_size == new_mamba_page_size, ( + f"Attention page size ({new_attn_page_size}) does not equal " + "mamba page size ({new_mamba_page_size}) after alignment.") return kv_cache_spec From ec6d840933843fe11e3b0f9a336f9cb518258447 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 4 Jul 2025 20:25:00 +0000 Subject: [PATCH 04/11] Update docs re: FlashInfer Signed-off-by: Thomas Parnell --- docs/usage/v1_guide.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 99994c9dff63..b6e7a99c91b4 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -114,7 +114,8 @@ enforcing eager mode and disabling prefix caching in V1. Models that combine Mamba-2 layers with standard transformer attention layers are supported (e.g., `BambaForCausalLM`, `Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that -these models currently require enforcing eager mode and disabling prefix caching in V1. +in V1 these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention +backend. 
#### Encoder-Decoder Models From 84daa1238b24f92a60c25b7bf48b23917338e493 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 8 Jul 2025 11:07:37 +0000 Subject: [PATCH 05/11] Address review feedback Signed-off-by: Thomas Parnell --- docs/usage/v1_guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index d587a8a8b36e..d7634223542d 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the | **Decoder-only Models** | 🚀 Optimized | | **Encoder-Decoder Models** | 🟠 Delayed | | **Embedding Models** | 🟢 Functional | -| **Mamba Models** | 🚧 WIP (, ) | +| **Mamba Models** | 🟢 (Mamba-2), 🟡 (Mamba-1) | | **Multimodal Models** | 🟢 Functional | vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol. From 5ea6bedfd4d2893a814bf02ecb7f9f51d73f97e7 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 8 Jul 2025 20:42:59 +0000 Subject: [PATCH 06/11] Second attempt at auto-setting attention block size Signed-off-by: Thomas Parnell --- vllm/model_executor/models/config.py | 176 +++++++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 173 +++++++++----------------- 2 files changed, 235 insertions(+), 114 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 552c4b074216..bd8f8ed5ec6f 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -3,7 +3,11 @@ from copy import deepcopy from typing import TYPE_CHECKING +import vllm.envs as envs +from vllm.distributed import divide from vllm.logger import init_logger +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec if TYPE_CHECKING: from vllm.config import VllmConfig @@ -191,10 +195,182 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: } +class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): + + @classmethod + def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int) -> int: + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + # for n_groups == 1, this is exactly tp_size - n_groups + return tp_size - ngroups + + @classmethod + def get_mamba_cache_shape( + cls, vllm_config: "VllmConfig" + ) -> tuple[tuple[int, int], tuple[int, int]]: + + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + if hasattr(hf_config, "mamba_expand"): + mamba_expand = hf_config.mamba_expand + elif hasattr(hf_config, "expand"): + # nemotron-h + mamba_expand = hf_config.expand + else: + raise ValueError("Cannot find mamba_expand in config.") + + if hasattr(hf_config, "mamba_n_groups"): + mamba_n_groups = hf_config.mamba_n_groups + elif hasattr(hf_config, "mamba_ngroups"): + # zamba2 + mamba_n_groups = hf_config.mamba_ngroups + elif hasattr(hf_config, "n_groups"): + # nemotron-h + mamba_n_groups = hf_config.n_groups + else: + raise ValueError("Cannot find mamba n_groups in config.") + + if hasattr(hf_config, "mamba_n_heads"): + mamba_n_heads = hf_config.mamba_n_heads + elif hasattr(hf_config, "n_mamba_heads"): + # zamba2 + mamba_n_heads = hf_config.n_mamba_heads + elif hasattr(hf_config, "mamba_num_heads"): + # nemotron-h + mamba_n_heads = hf_config.mamba_num_heads + else: + 
raise ValueError("Cannot find mamba n_heads in config.") + + if hasattr(hf_config, "mamba_d_head"): + mamba_d_head = hf_config.mamba_d_head + elif hasattr(hf_config, "mamba_headdim"): + # zamba2 + mamba_d_head = hf_config.mamba_headdim + elif hasattr(hf_config, "mamba_head_dim"): + # nemotron-h + mamba_d_head = hf_config.mamba_head_dim + else: + raise ValueError("Cannot find mamba d_head in config.") + + if hasattr(hf_config, "mamba_d_state"): + mamba_d_state = hf_config.mamba_d_state + elif hasattr(hf_config, "ssm_state_size"): + # nemotron-h + mamba_d_state = hf_config.ssm_state_size + else: + raise ValueError("Cannot find mamba d_state in config.") + + if hasattr(hf_config, "mamba_d_conv"): + mamba_d_conv = hf_config.mamba_d_conv + elif hasattr(hf_config, "conv_kernel"): + # nemotron-h + mamba_d_conv = hf_config.conv_kernel + else: + raise ValueError("Cannot find mamba d_conv in config.") + + world_size = parallel_config.tensor_parallel_size + hidden_size = hf_config.hidden_size + intermediate_size = mamba_expand * hidden_size + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = ( + mamba_n_groups + + cls.extra_groups_for_head_shards(mamba_n_groups, world_size)) + + # - heads and n_groups are TP-ed + conv_dim = (intermediate_size + 2 * n_groups * mamba_d_state) + conv_state_shape = ( + divide(conv_dim, world_size), + mamba_d_conv - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(mamba_n_heads, world_size), + mamba_d_head, + mamba_d_state, + ) + + return conv_state_shape, temporal_state_shape + + @classmethod + def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: + + if not envs.VLLM_USE_V1: + return + + cache_config = vllm_config.cache_config + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + + if cache_config.cache_dtype == "auto": + kv_cache_dtype = model_config.dtype + else: + kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # get attention page size (for 1 token) + attn_page_size_1_token = FullAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype, + use_mla=model_config.use_mla).page_size_bytes + + # get mamba page size + mamba_page_size = MambaSpec( + shapes=cls.get_mamba_cache_shape(vllm_config), + dtype=kv_cache_dtype, + block_size=model_config.max_model_len, + ).page_size_bytes + + # some attention backends (e.g. FA) only support setting + # block size to multiple of 16, so let's suggest a value + # that would work (note: FA is currently not compatible + # with mamba layers, use FlashInfer instead). + attn_block_size = 16 * cdiv(mamba_page_size, + 16 * attn_page_size_1_token) + + logger.info( + "Setting default attention block size to %d tokens " + "to ensure that attention page size is >= mamba page size.", + attn_block_size) + + mamba_page_size_padded = attn_block_size * attn_page_size_1_token + + mamba_padding_pct = 100 * (mamba_page_size_padded - + mamba_page_size) / mamba_page_size + + logger.info( + "Padding mamba page size by %.2f%% to ensure " + "that mamba page size and attention page size are " + "exactly equal.", mamba_padding_pct) + + # override attention block size if either (a) the + # user has not set it or (b) the user has set it + # too small. 
+ if (cache_config.block_size is not None + and cache_config.block_size < attn_block_size): + cache_config.block_size = attn_block_size + + MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, "NomicBertModel": NomicBertModelConfig, "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig, "XLMRobertaModel": JinaRobertaModelConfig, + "FalconH1ForCausalLM": HybridAttentionMambaModelConfig, + "BambaForCausalLM": HybridAttentionMambaModelConfig, + "GraniteMoeHybridForCausalLM": HybridAttentionMambaModelConfig, + "NemotronHForCausalLM": HybridAttentionMambaModelConfig, + "Zamba2ForCausalLM": HybridAttentionMambaModelConfig, } diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 29b38fe41a34..8658d7d916f0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2615,107 +2615,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. """ - attn_block_size = self.vllm_config.cache_config.block_size - - kv_cache_spec, attn_page_size, mamba_page_size = \ - self._get_kv_cache_spec_for_block_size( - attn_block_size=attn_block_size, - ) - - if (attn_page_size is not None and mamba_page_size is not None - and attn_page_size != mamba_page_size): - - kv_cache_spec = self._align_hybrid_page_size( - attn_page_size, - mamba_page_size, - attn_block_size, - ) - - return kv_cache_spec - - def _align_hybrid_page_size( - self, attn_page_size: int, mamba_page_size: int, - attn_block_size: int) -> dict[str, KVCacheSpec]: - """ - Ensures that the page size for attention layers and mamba layers - are equal. This is achieved by (1) ensuring that the attention - block size is set such that the page size is greater than or - equal to the mamba page size and then (2) padding the mamba page - size to ensure they are precisely equal. - - Args: - attn_page_size: Page size (bytes) for attention layers. - mamba_page_size: Page size (bytes) for mamba layers. - attn_block_size: Block size (tokens) for attention layers. - Returns: - KVCacheSpec: A dictionary mapping layer names to their KV cache - format. Layers that do not need KV cache are not included. - """ - - # check if we need to increase attention block size - if attn_page_size < mamba_page_size: - - # attention page size (for 1 token) - attn_page_size_1 = attn_page_size // attn_block_size - - # some attention backends (e.g. FA) only support setting - # block size to multiple of 16, so let's suggest a value - # that would work (note: FA is currently not compatible - # with mamba layers, use FlashInfer instead). 
- attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1) - - logger.info( - "Increasing attention block size from %d to %d tokens " - "to ensure that attention page size is >= mamba page size.", - self.vllm_config.cache_config.block_size, attn_block_size) - - mamba_page_size_padded = attn_block_size * attn_page_size_1 - - mamba_padding_pct = 100 * (mamba_page_size_padded - - mamba_page_size) / mamba_page_size - - logger.info( - "Padding mamba page size by %.2f%% to ensure " - "that mamba page size and attention page size are " - "exactly equal.", mamba_padding_pct) - - kv_cache_spec, new_attn_page_size, new_mamba_page_size = \ - self._get_kv_cache_spec_for_block_size( - attn_block_size=attn_block_size, - mamba_page_size_padded=mamba_page_size_padded, - ) - - assert new_attn_page_size == new_mamba_page_size, ( - f"Attention page size ({new_attn_page_size}) does not equal " - "mamba page size ({new_mamba_page_size}) after alignment.") - - return kv_cache_spec - - def _get_kv_cache_spec_for_block_size( - self, - attn_block_size: int, - mamba_page_size_padded: Optional[int] = None, - ) -> tuple[dict[str, KVCacheSpec], Optional[int], Optional[int]]: - """ - Generates the KVCacheSpec by parsing the kv cache format from each - Attention module in the static forward context, assuming a given - block size (in tokens) for attention layers. - - Args: - attn_block_size: Block size (in tokens) for attention layers. - mamba_page_size_padded: Padded page size for mamba layers. - Returns: - KVCacheSpec: A dictionary mapping layer names to their KV cache - format. Layers that do not need KV cache are not included. - Optional[int]: Page size for attention layers. - Optional[int]: Page size for mamba layers. - """ - + block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) - - attn_page_size = None for layer_name, attn_module in attn_layers.items(): if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: @@ -2733,7 +2636,7 @@ def _get_kv_cache_spec_for_block_size( if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( - block_size=attn_block_size, + block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, @@ -2741,17 +2644,11 @@ def _get_kv_cache_spec_for_block_size( use_mla=use_mla) else: kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=attn_block_size, + block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, use_mla=use_mla) - page_size = kv_cache_spec[layer_name].page_size_bytes - if attn_page_size is None: - attn_page_size = page_size - else: - assert page_size == attn_page_size - elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. 
@@ -2764,7 +2661,6 @@ def _get_kv_cache_spec_for_block_size( mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaMixer2) - mamba_page_size = None if len(mamba_layers) > 0: if self.vllm_config.speculative_config is not None: raise NotImplementedError( @@ -2777,6 +2673,10 @@ def _get_kv_cache_spec_for_block_size( "Prefix caching is not supported for Mamba yet.") max_model_len = self.vllm_config.model_config.max_model_len + page_size_padded = self._maybe_pad_mamba_page_size( + attn_layers, mamba_layers, kv_cache_spec, max_model_len, + block_size) + # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. for layer_name, mamba_module in mamba_layers.items(): @@ -2784,12 +2684,57 @@ def _get_kv_cache_spec_for_block_size( shapes=mamba_module.get_state_shape(), dtype=self.kv_cache_dtype, block_size=max_model_len, - page_size_padded=mamba_page_size_padded) + page_size_padded=page_size_padded) - page_size = kv_cache_spec[layer_name].page_size_bytes - if mamba_page_size is None: - mamba_page_size = page_size - else: - assert page_size == mamba_page_size + return kv_cache_spec + + def _maybe_pad_mamba_page_size( + self, + attn_layers: dict[str, Attention], + mamba_layers: dict[str, MambaMixer2], + kv_cache_spec: dict[str, KVCacheSpec], + max_model_len: int, + block_size: int, + ) -> Optional[int]: + """ + Ensure that page size of attention KV cache groups is greater than or + equal to the mamba KV cache groups. If not, we suggest to the user + how to set the attention block size to ensure that it is. + + If the attention page size is strictly greater than the mamba page size, + we pad the mamba page size to make them equal. + + Args: + attn_layers: Attention layers + mamba_layers: Mamba layers + kv_cache_spec: KV cache spec (populated with attention layers) + + Returns: + Optional[int]: Mamba page size with padding (None if no padding). + """ - return kv_cache_spec, attn_page_size, mamba_page_size + if len(attn_layers) == 0: + return None + + attn_layer_name = next(iter(attn_layers)) + attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes + mamba_layer_name = next(iter(mamba_layers)) + mamba_page_size = MambaSpec( + shapes=mamba_layers[mamba_layer_name].get_state_shape(), + dtype=self.kv_cache_dtype, + block_size=max_model_len).page_size_bytes + if attn_page_size < mamba_page_size: + # attention page size (for 16 tokens) + attn_page_size_16 = 16 * attn_page_size // block_size + # some attention backends (e.g. FA) only support setting + # block size to multiple of 16, so let's suggest a value + # that would work (note: FA is currently not compatible + # with mamba layers, use FlashInfer instead). 
+ suggest_attn_block_size = 16 * cdiv(mamba_page_size, + attn_page_size_16) + raise ValueError( + "Attention block size should be increased to at least " + f"{suggest_attn_block_size} in order to match " + "the mamba page size") + + return attn_page_size From 1df73195b7737ab3e62efcbc71d6877295740135 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 8 Jul 2025 20:57:22 +0000 Subject: [PATCH 07/11] Fix logic slightly Signed-off-by: Thomas Parnell --- vllm/model_executor/models/config.py | 35 ++++++++++++++-------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index bd8f8ed5ec6f..cc81a4d53c63 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -339,27 +339,28 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token) - logger.info( - "Setting default attention block size to %d tokens " - "to ensure that attention page size is >= mamba page size.", - attn_block_size) - - mamba_page_size_padded = attn_block_size * attn_page_size_1_token - - mamba_padding_pct = 100 * (mamba_page_size_padded - - mamba_page_size) / mamba_page_size - - logger.info( - "Padding mamba page size by %.2f%% to ensure " - "that mamba page size and attention page size are " - "exactly equal.", mamba_padding_pct) - # override attention block size if either (a) the # user has not set it or (b) the user has set it # too small. - if (cache_config.block_size is not None - and cache_config.block_size < attn_block_size): + if (cache_config.block_size is None + or cache_config.block_size < attn_block_size): cache_config.block_size = attn_block_size + logger.info( + "Setting attention block size to %d tokens " + "to ensure that attention page size is >= mamba page size.", + attn_block_size) + + # mamba page size will be padded up to match attention page size + mamba_page_size_padded = \ + cache_config.block_size * attn_page_size_1_token + + if mamba_page_size_padded > mamba_page_size: + mamba_padding_pct = 100 * (mamba_page_size_padded - + mamba_page_size) / mamba_page_size + logger.info( + "Padding mamba page size by %.2f%% to ensure " + "that mamba page size and attention page size are " + "exactly equal.", mamba_padding_pct) MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { From d04dcfe968d9ffad077ba95b4b0003b69f8b6c19 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 9 Jul 2025 04:22:07 +0000 Subject: [PATCH 08/11] Cleanup Signed-off-by: Thomas Parnell --- vllm/config.py | 3 + vllm/model_executor/models/config.py | 161 +++++++++++++++------------ vllm/v1/worker/gpu_model_runner.py | 56 +--------- 3 files changed, 94 insertions(+), 126 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 90cf885a40d4..0bdb86f158af 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1553,6 +1553,9 @@ class CacheConfig: checkpoint if available. Otherwise, the scales will default to 1.0.""" cpu_kvcache_space_bytes: Optional[int] = None """(CPU backend only) CPU key-value cache space.""" + mamba_page_size_padded: Optional[int] = None + """ Optional override for mamba page size; used by hybrid mamaba/attention + models to ensure exact alignment with attention page size.""" # Will be set after profiling. 
num_gpu_blocks: Optional[int] = field(default=None, init=False) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index cc81a4d53c63..fd5855193ff2 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy +from dataclasses import dataclass from typing import TYPE_CHECKING import vllm.envs as envs @@ -10,6 +11,8 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec if TYPE_CHECKING: + from transformers.configuration_utils import PretrainedConfig + from vllm.config import VllmConfig logger = init_logger(__name__) @@ -209,6 +212,26 @@ def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int) -> int: # for n_groups == 1, this is exactly tp_size - n_groups return tp_size - ngroups + @dataclass + class MambaConfig: + expand: int + n_groups: int + n_heads: int + d_head: int + d_state: int + d_conv: int + + @classmethod + def parse_mamba_config(cls, config: "PretrainedConfig") -> MambaConfig: + return cls.MambaConfig( + expand=config.mamba_expand, + n_groups=config.mamba_n_groups, + n_heads=config.mamba_n_heads, + d_head=config.mamba_d_head, + d_state=config.mamba_d_state, + d_conv=config.mamba_d_conv, + ) + @classmethod def get_mamba_cache_shape( cls, vllm_config: "VllmConfig" @@ -216,94 +239,47 @@ def get_mamba_cache_shape( parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config - - if hasattr(hf_config, "mamba_expand"): - mamba_expand = hf_config.mamba_expand - elif hasattr(hf_config, "expand"): - # nemotron-h - mamba_expand = hf_config.expand - else: - raise ValueError("Cannot find mamba_expand in config.") - - if hasattr(hf_config, "mamba_n_groups"): - mamba_n_groups = hf_config.mamba_n_groups - elif hasattr(hf_config, "mamba_ngroups"): - # zamba2 - mamba_n_groups = hf_config.mamba_ngroups - elif hasattr(hf_config, "n_groups"): - # nemotron-h - mamba_n_groups = hf_config.n_groups - else: - raise ValueError("Cannot find mamba n_groups in config.") - - if hasattr(hf_config, "mamba_n_heads"): - mamba_n_heads = hf_config.mamba_n_heads - elif hasattr(hf_config, "n_mamba_heads"): - # zamba2 - mamba_n_heads = hf_config.n_mamba_heads - elif hasattr(hf_config, "mamba_num_heads"): - # nemotron-h - mamba_n_heads = hf_config.mamba_num_heads - else: - raise ValueError("Cannot find mamba n_heads in config.") - - if hasattr(hf_config, "mamba_d_head"): - mamba_d_head = hf_config.mamba_d_head - elif hasattr(hf_config, "mamba_headdim"): - # zamba2 - mamba_d_head = hf_config.mamba_headdim - elif hasattr(hf_config, "mamba_head_dim"): - # nemotron-h - mamba_d_head = hf_config.mamba_head_dim - else: - raise ValueError("Cannot find mamba d_head in config.") - - if hasattr(hf_config, "mamba_d_state"): - mamba_d_state = hf_config.mamba_d_state - elif hasattr(hf_config, "ssm_state_size"): - # nemotron-h - mamba_d_state = hf_config.ssm_state_size - else: - raise ValueError("Cannot find mamba d_state in config.") - - if hasattr(hf_config, "mamba_d_conv"): - mamba_d_conv = hf_config.mamba_d_conv - elif hasattr(hf_config, "conv_kernel"): - # nemotron-h - mamba_d_conv = hf_config.conv_kernel - else: - raise ValueError("Cannot find mamba d_conv in config.") + mamba_config = cls.parse_mamba_config(hf_config) world_size = parallel_config.tensor_parallel_size hidden_size = hf_config.hidden_size - intermediate_size = mamba_expand * hidden_size + 
intermediate_size = mamba_config.expand * hidden_size # if n_groups is not divisible by world_size, need to extend the shards # to ensure all groups needed by a head is sharded along with it - n_groups = ( - mamba_n_groups + - cls.extra_groups_for_head_shards(mamba_n_groups, world_size)) + n_groups = (mamba_config.n_groups + cls.extra_groups_for_head_shards( + mamba_config.n_groups, world_size)) # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + 2 * n_groups * mamba_d_state) + conv_dim = (intermediate_size + 2 * n_groups * mamba_config.d_state) conv_state_shape = ( divide(conv_dim, world_size), - mamba_d_conv - 1, + mamba_config.d_conv - 1, ) # These are not TP-ed as they depend on A, dt_bias, D # - they are typically small # e.g., (h_heads, d_head, d_state) = (128, 64, 128) temporal_state_shape = ( - divide(mamba_n_heads, world_size), - mamba_d_head, - mamba_d_state, + divide(mamba_config.n_heads, world_size), + mamba_config.d_head, + mamba_config.d_state, ) return conv_state_shape, temporal_state_shape @classmethod def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: + """ + Ensure that page size of attention layers is greater than or + equal to the mamba layers. If not, automatically set the attention + block size to ensure that it is. If the attention page size is + strictly greater than the mamba page size, we pad the mamba page size + to make them equal. + + Args: + vllm_config: vLLM Config + """ if not envs.VLLM_USE_V1: return @@ -350,12 +326,21 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "to ensure that attention page size is >= mamba page size.", attn_block_size) - # mamba page size will be padded up to match attention page size - mamba_page_size_padded = \ + # compute new attention page size + attn_page_size = \ cache_config.block_size * attn_page_size_1_token - if mamba_page_size_padded > mamba_page_size: - mamba_padding_pct = 100 * (mamba_page_size_padded - + assert attn_page_size >= mamba_page_size + + if attn_page_size == mamba_page_size: + # don't need to pad mamba page size + return + + # pad mamba page size to exactly match attention + if (cache_config.mamba_page_size_padded is None + or cache_config.mamba_page_size_padded != attn_page_size): + cache_config.mamba_page_size_padded = (attn_page_size) + mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size logger.info( "Padding mamba page size by %.2f%% to ensure " @@ -363,6 +348,38 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "exactly equal.", mamba_padding_pct) +class NemotronHModelConfig(HybridAttentionMambaModelConfig): + + @classmethod + def parse_mamba_config( + cls, config: "PretrainedConfig" + ) -> HybridAttentionMambaModelConfig.MambaConfig: + return HybridAttentionMambaModelConfig.MambaConfig( + expand=config.expand, + n_groups=config.n_groups, + n_heads=config.mamba_num_heads, + d_head=config.mamba_head_dim, + d_state=config.ssm_state_size, + d_conv=config.conv_kernel, + ) + + +class Zamba2ModelConfig(HybridAttentionMambaModelConfig): + + @classmethod + def parse_mamba_config( + cls, config: "PretrainedConfig" + ) -> HybridAttentionMambaModelConfig.MambaConfig: + return HybridAttentionMambaModelConfig.MambaConfig( + expand=config.mamba_expand, + n_groups=config.mamba_ngroups, + n_heads=config.n_mamba_heads, + d_head=config.mamba_headdim, + d_state=config.mamba_d_state, + d_conv=config.mamba_d_conv, + ) + + MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": 
SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, @@ -372,6 +389,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "FalconH1ForCausalLM": HybridAttentionMambaModelConfig, "BambaForCausalLM": HybridAttentionMambaModelConfig, "GraniteMoeHybridForCausalLM": HybridAttentionMambaModelConfig, - "NemotronHForCausalLM": HybridAttentionMambaModelConfig, - "Zamba2ForCausalLM": HybridAttentionMambaModelConfig, + "NemotronHForCausalLM": NemotronHModelConfig, + "Zamba2ForCausalLM": Zamba2ModelConfig, } diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8658d7d916f0..8d4e0e14f32f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2673,9 +2673,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: "Prefix caching is not supported for Mamba yet.") max_model_len = self.vllm_config.model_config.max_model_len - page_size_padded = self._maybe_pad_mamba_page_size( - attn_layers, mamba_layers, kv_cache_spec, max_model_len, - block_size) + page_size_padded = ( + self.vllm_config.cache_config.mamba_page_size_padded) # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. @@ -2687,54 +2686,3 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: page_size_padded=page_size_padded) return kv_cache_spec - - def _maybe_pad_mamba_page_size( - self, - attn_layers: dict[str, Attention], - mamba_layers: dict[str, MambaMixer2], - kv_cache_spec: dict[str, KVCacheSpec], - max_model_len: int, - block_size: int, - ) -> Optional[int]: - """ - Ensure that page size of attention KV cache groups is greater than or - equal to the mamba KV cache groups. If not, we suggest to the user - how to set the attention block size to ensure that it is. - - If the attention page size is strictly greater than the mamba page size, - we pad the mamba page size to make them equal. - - Args: - attn_layers: Attention layers - mamba_layers: Mamba layers - kv_cache_spec: KV cache spec (populated with attention layers) - - Returns: - Optional[int]: Mamba page size with padding (None if no padding). - """ - - if len(attn_layers) == 0: - return None - - attn_layer_name = next(iter(attn_layers)) - attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes - mamba_layer_name = next(iter(mamba_layers)) - mamba_page_size = MambaSpec( - shapes=mamba_layers[mamba_layer_name].get_state_shape(), - dtype=self.kv_cache_dtype, - block_size=max_model_len).page_size_bytes - if attn_page_size < mamba_page_size: - # attention page size (for 16 tokens) - attn_page_size_16 = 16 * attn_page_size // block_size - # some attention backends (e.g. FA) only support setting - # block size to multiple of 16, so let's suggest a value - # that would work (note: FA is currently not compatible - # with mamba layers, use FlashInfer instead). 
- suggest_attn_block_size = 16 * cdiv(mamba_page_size, - attn_page_size_16) - raise ValueError( - "Attention block size should be increased to at least " - f"{suggest_attn_block_size} in order to match " - "the mamba page size") - - return attn_page_size From 2ff9a0967da5391a81ac2e9a0d48fce0a9bc4ac7 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 9 Jul 2025 04:40:43 +0000 Subject: [PATCH 09/11] Remove unused import Signed-off-by: Thomas Parnell --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2f03716e3992..d9d5b0a63330 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -43,7 +43,7 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, + GiB_bytes, LazyLoader, async_tensor_h2d, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up) from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend From bcd9376aae0838cd9ca68eb1a361f64adab0e990 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 9 Jul 2025 07:14:18 +0000 Subject: [PATCH 10/11] Revert other changs; update docs Signed-off-by: Thomas Parnell --- docs/usage/v1_guide.md | 3 +- .../models/language/generation/test_hybrid.py | 16 +- vllm/config.py | 3 - vllm/model_executor/models/config.py | 194 ------------------ vllm/v1/worker/gpu_model_runner.py | 58 +++++- 5 files changed, 72 insertions(+), 202 deletions(-) diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index d7634223542d..459ea2d676c1 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -112,7 +112,8 @@ enforcing eager mode and disabling prefix caching in V1. Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, `Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention -backend in V1. +backend in V1. It is also necessary to pass a non-standard block size for attention layers (this is not possible +using the `vllm serve` CLI yet). 
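Until the CLI accepts such block sizes, a minimal sketch of an offline setup that satisfies these constraints (the model name and the 528-token block size are taken from the `ATTN_BLOCK_SIZES` test table; one way to select FlashInfer is the `VLLM_ATTENTION_BACKEND` environment variable) looks like:

```python
import os

# V1 engine with the FlashInfer attention backend, per the requirements above.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-ai-platform/Bamba-9B-v1",
    enforce_eager=True,           # hybrid models currently require eager mode in V1
    enable_prefix_caching=False,  # prefix caching is not supported yet
    block_size=528,               # attention block large enough to cover the mamba page
)

out = llm.generate(["The capital of France is"],
                   SamplingParams(temperature=0.0, max_tokens=8))
print(out[0].outputs[0].text)
```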
#### Encoder-Decoder Models diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index eba14e64553e..ecaae3ec1fc4 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -61,6 +61,14 @@ "tiiuae/Falcon-H1-0.5B-Base", ] +ATTN_BLOCK_SIZES = { + "ibm-ai-platform/Bamba-9B-v1": 528, + "Zyphra/Zamba2-1.2B-instruct": 80, + "nvidia/Nemotron-H-8B-Base-8K": 528, + "ibm-granite/granite-4.0-tiny-preview": 400, + "tiiuae/Falcon-H1-0.5B-Base": 800, +} + # Avoid OOM MAX_NUM_SEQS = 4 @@ -97,6 +105,11 @@ def test_models( example_prompts, max_tokens, num_logprobs) if model in V1_SUPPORTED_MODELS: + if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES: + block_size = ATTN_BLOCK_SIZES[model] + else: + block_size = 16 + with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") if model in HYBRID_MODELS: @@ -105,7 +118,8 @@ def test_models( with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, enforce_eager=True, - enable_prefix_caching=False) as vllm_model: + enable_prefix_caching=False, + block_size=block_size) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) else: diff --git a/vllm/config.py b/vllm/config.py index 8ba6db1f3387..508e09174cc8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1553,9 +1553,6 @@ class CacheConfig: checkpoint if available. Otherwise, the scales will default to 1.0.""" cpu_kvcache_space_bytes: Optional[int] = None """(CPU backend only) CPU key-value cache space.""" - mamba_page_size_padded: Optional[int] = None - """ Optional override for mamba page size; used by hybrid mamaba/attention - models to ensure exact alignment with attention page size.""" # Will be set after profiling. 
num_gpu_blocks: Optional[int] = field(default=None, init=False) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index fd5855193ff2..552c4b074216 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -1,18 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy -from dataclasses import dataclass from typing import TYPE_CHECKING -import vllm.envs as envs -from vllm.distributed import divide from vllm.logger import init_logger -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv -from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec if TYPE_CHECKING: - from transformers.configuration_utils import PretrainedConfig - from vllm.config import VllmConfig logger = init_logger(__name__) @@ -198,197 +191,10 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: } -class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): - - @classmethod - def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int) -> int: - """Compute the increase in group numbers to account for - replication in order to accompany the head shards.""" - - # in the case ngoups % tp_size == 0, this will be zero - if ngroups % tp_size == 0: - return 0 - - # for n_groups == 1, this is exactly tp_size - n_groups - return tp_size - ngroups - - @dataclass - class MambaConfig: - expand: int - n_groups: int - n_heads: int - d_head: int - d_state: int - d_conv: int - - @classmethod - def parse_mamba_config(cls, config: "PretrainedConfig") -> MambaConfig: - return cls.MambaConfig( - expand=config.mamba_expand, - n_groups=config.mamba_n_groups, - n_heads=config.mamba_n_heads, - d_head=config.mamba_d_head, - d_state=config.mamba_d_state, - d_conv=config.mamba_d_conv, - ) - - @classmethod - def get_mamba_cache_shape( - cls, vllm_config: "VllmConfig" - ) -> tuple[tuple[int, int], tuple[int, int]]: - - parallel_config = vllm_config.parallel_config - hf_config = vllm_config.model_config.hf_config - mamba_config = cls.parse_mamba_config(hf_config) - - world_size = parallel_config.tensor_parallel_size - hidden_size = hf_config.hidden_size - intermediate_size = mamba_config.expand * hidden_size - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = (mamba_config.n_groups + cls.extra_groups_for_head_shards( - mamba_config.n_groups, world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + 2 * n_groups * mamba_config.d_state) - conv_state_shape = ( - divide(conv_dim, world_size), - mamba_config.d_conv - 1, - ) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(mamba_config.n_heads, world_size), - mamba_config.d_head, - mamba_config.d_state, - ) - - return conv_state_shape, temporal_state_shape - - @classmethod - def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: - """ - Ensure that page size of attention layers is greater than or - equal to the mamba layers. If not, automatically set the attention - block size to ensure that it is. If the attention page size is - strictly greater than the mamba page size, we pad the mamba page size - to make them equal. 
- - Args: - vllm_config: vLLM Config - """ - - if not envs.VLLM_USE_V1: - return - - cache_config = vllm_config.cache_config - model_config = vllm_config.model_config - parallel_config = vllm_config.parallel_config - - if cache_config.cache_dtype == "auto": - kv_cache_dtype = model_config.dtype - else: - kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - - # get attention page size (for 1 token) - attn_page_size_1_token = FullAttentionSpec( - block_size=1, - num_kv_heads=model_config.get_num_kv_heads(parallel_config), - head_size=model_config.get_head_size(), - dtype=kv_cache_dtype, - use_mla=model_config.use_mla).page_size_bytes - - # get mamba page size - mamba_page_size = MambaSpec( - shapes=cls.get_mamba_cache_shape(vllm_config), - dtype=kv_cache_dtype, - block_size=model_config.max_model_len, - ).page_size_bytes - - # some attention backends (e.g. FA) only support setting - # block size to multiple of 16, so let's suggest a value - # that would work (note: FA is currently not compatible - # with mamba layers, use FlashInfer instead). - attn_block_size = 16 * cdiv(mamba_page_size, - 16 * attn_page_size_1_token) - - # override attention block size if either (a) the - # user has not set it or (b) the user has set it - # too small. - if (cache_config.block_size is None - or cache_config.block_size < attn_block_size): - cache_config.block_size = attn_block_size - logger.info( - "Setting attention block size to %d tokens " - "to ensure that attention page size is >= mamba page size.", - attn_block_size) - - # compute new attention page size - attn_page_size = \ - cache_config.block_size * attn_page_size_1_token - - assert attn_page_size >= mamba_page_size - - if attn_page_size == mamba_page_size: - # don't need to pad mamba page size - return - - # pad mamba page size to exactly match attention - if (cache_config.mamba_page_size_padded is None - or cache_config.mamba_page_size_padded != attn_page_size): - cache_config.mamba_page_size_padded = (attn_page_size) - mamba_padding_pct = 100 * (attn_page_size - - mamba_page_size) / mamba_page_size - logger.info( - "Padding mamba page size by %.2f%% to ensure " - "that mamba page size and attention page size are " - "exactly equal.", mamba_padding_pct) - - -class NemotronHModelConfig(HybridAttentionMambaModelConfig): - - @classmethod - def parse_mamba_config( - cls, config: "PretrainedConfig" - ) -> HybridAttentionMambaModelConfig.MambaConfig: - return HybridAttentionMambaModelConfig.MambaConfig( - expand=config.expand, - n_groups=config.n_groups, - n_heads=config.mamba_num_heads, - d_head=config.mamba_head_dim, - d_state=config.ssm_state_size, - d_conv=config.conv_kernel, - ) - - -class Zamba2ModelConfig(HybridAttentionMambaModelConfig): - - @classmethod - def parse_mamba_config( - cls, config: "PretrainedConfig" - ) -> HybridAttentionMambaModelConfig.MambaConfig: - return HybridAttentionMambaModelConfig.MambaConfig( - expand=config.mamba_expand, - n_groups=config.mamba_ngroups, - n_heads=config.n_mamba_heads, - d_head=config.mamba_headdim, - d_state=config.mamba_d_state, - d_conv=config.mamba_d_conv, - ) - - MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, "NomicBertModel": NomicBertModelConfig, "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig, "XLMRobertaModel": JinaRobertaModelConfig, - "FalconH1ForCausalLM": HybridAttentionMambaModelConfig, - "BambaForCausalLM": HybridAttentionMambaModelConfig, - 
"GraniteMoeHybridForCausalLM": HybridAttentionMambaModelConfig, - "NemotronHForCausalLM": NemotronHModelConfig, - "Zamba2ForCausalLM": Zamba2ModelConfig, } diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d9d5b0a63330..ef03626cf14d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -43,7 +43,7 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, async_tensor_h2d, + GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up) from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend @@ -2675,8 +2675,9 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: "Prefix caching is not supported for Mamba yet.") max_model_len = self.vllm_config.model_config.max_model_len - page_size_padded = ( - self.vllm_config.cache_config.mamba_page_size_padded) + page_size_padded = self._maybe_pad_mamba_page_size( + attn_layers, mamba_layers, kv_cache_spec, max_model_len, + block_size) # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. @@ -2688,3 +2689,54 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: page_size_padded=page_size_padded) return kv_cache_spec + + def _maybe_pad_mamba_page_size( + self, + attn_layers: dict[str, Attention], + mamba_layers: dict[str, MambaMixer2], + kv_cache_spec: dict[str, KVCacheSpec], + max_model_len: int, + block_size: int, + ) -> Optional[int]: + """ + Ensure that page size of attention KV cache groups is greater than or + equal to the mamba KV cache groups. If not, we suggest to the user + how to set the attention block size to ensure that it is. + + If the attention page size is strictly greater than the mamba page size, + we pad the mamba page size to make them equal. + + Args: + attn_layers: Attention layers + mamba_layers: Mamba layers + kv_cache_spec: KV cache spec (populated with attention layers) + + Returns: + Optional[int]: Mamba page size with padding (None if no padding). + """ + + if len(attn_layers) == 0: + return None + + attn_layer_name = next(iter(attn_layers)) + attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes + mamba_layer_name = next(iter(mamba_layers)) + mamba_page_size = MambaSpec( + shapes=mamba_layers[mamba_layer_name].get_state_shape(), + dtype=self.kv_cache_dtype, + block_size=max_model_len).page_size_bytes + if attn_page_size < mamba_page_size: + # attention page size (for 16 tokens) + attn_page_size_16 = 16 * attn_page_size // block_size + # some attention backends (e.g. FA) only support setting + # block size to multiple of 16, so let's suggest a value + # that would work (note: FA is currently not compatible + # with mamba layers, use FlashInfer instead). 
+ suggest_attn_block_size = 16 * cdiv(mamba_page_size, + attn_page_size_16) + raise ValueError( + "Attention block size should be increased to at least " + f"{suggest_attn_block_size} in order to match " + "the mamba page size") + + return attn_page_size From 61ec90b55295c3883e73b4d139b97fec5496b265 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 9 Jul 2025 09:57:50 +0200 Subject: [PATCH 11/11] Apply suggestions from code review Unicode icon weirdness Co-authored-by: Cyrus Leung Signed-off-by: Thomas Parnell --- docs/models/supported_models.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 51abdc72b72c..e75d656af283 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -316,7 +316,7 @@ Specified using `--task generate`. | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅ | +| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -332,7 +332,7 @@ Specified using `--task generate`. | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ | | `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ | -| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅ | +| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -345,7 +345,7 @@ Specified using `--task generate`. | `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | | `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. 
| ✅︎ | ✅︎ | ✅ | +| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | @@ -357,14 +357,14 @@ Specified using `--task generate`. | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | -| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅ | +| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | | `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅ | +| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ | | `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | @@ -389,7 +389,7 @@ Specified using `--task generate`. | `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | | | `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | | -| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅ | +| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ | !!! note Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
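
Note for reviewers: the block-size selection implemented in `HybridAttentionMambaModelConfig.verify_and_update_config` (and mirrored by `_maybe_pad_mamba_page_size` in `gpu_model_runner.py` above) reduces to a small piece of arithmetic. The following minimal, standalone sketch illustrates it; the byte counts are made-up example values, not numbers taken from any model in this patch, and `cdiv` is re-implemented locally to keep the snippet self-contained.

# Illustrative sketch of the attention-block-size / mamba-page-size matching
# performed by this patch. Example numbers are assumptions for demonstration.

def cdiv(a: int, b: int) -> int:
    """Ceiling division, matching the behaviour of vllm.utils.cdiv."""
    return -(-a // b)

# Assumed example inputs:
attn_page_size_1_token = 4096   # bytes of KV cache per token for one attention layer
mamba_page_size = 2_700_000     # bytes of conv + SSM state per sequence for one mamba layer

# Some attention backends (e.g. FlashAttention) require the block size to be a
# multiple of 16, so the suggested block size is rounded up accordingly.
attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)

# The resulting attention page size is now >= the mamba page size ...
attn_page_size = attn_block_size * attn_page_size_1_token
assert attn_page_size >= mamba_page_size

# ... and the mamba page size is padded up to match it exactly, so both KV
# cache groups can share pages of a single size.
mamba_page_size_padded = attn_page_size
padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
print(attn_block_size, attn_page_size, f"{padding_pct:.2f}%")

With these example numbers the suggested attention block size comes out to 672 tokens, and the mamba state is padded by roughly 1.9% so that the attention and mamba KV cache groups end up with exactly equal page sizes.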