Commit bcd9376

Revert other changs; update docs
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent 2ff9a09 commit bcd9376

File tree: 5 files changed (+72, -202 lines)

docs/usage/v1_guide.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -112,7 +112,8 @@ enforcing eager mode and disabling prefix caching in V1.
 Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
 `Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
 these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention
-backend in V1.
+backend in V1. It is also necessary to pass a non-standard block size for attention layers (this is not possible
+using the `vllm serve` CLI yet).
 
 #### Encoder-Decoder Models
 
```
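For convenience, a minimal offline-inference sketch of the configuration described in this doc change follows. The block size of 528 for `ibm-ai-platform/Bamba-9B-v1` is taken from the test table in this commit; the environment variables and sampling settings are illustrative assumptions, not an official recipe.

```python
# Hedged sketch: run a hybrid Mamba-2/attention model on V1 with an enlarged
# attention block size. Values other than the model name and block size are
# illustrative assumptions.
import os

os.environ["VLLM_USE_V1"] = "1"                       # V1 engine, as in the test
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"   # assumed backend selector

from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-ai-platform/Bamba-9B-v1",
    enforce_eager=True,           # required for these models in V1 (per docs)
    enable_prefix_caching=False,  # prefix caching not yet supported
    block_size=528,               # non-standard attention block size
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```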
tests/models/language/generation/test_hybrid.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -61,6 +61,14 @@
     "tiiuae/Falcon-H1-0.5B-Base",
 ]
 
+ATTN_BLOCK_SIZES = {
+    "ibm-ai-platform/Bamba-9B-v1": 528,
+    "Zyphra/Zamba2-1.2B-instruct": 80,
+    "nvidia/Nemotron-H-8B-Base-8K": 528,
+    "ibm-granite/granite-4.0-tiny-preview": 400,
+    "tiiuae/Falcon-H1-0.5B-Base": 800,
+}
+
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
@@ -97,6 +105,11 @@ def test_models(
             example_prompts, max_tokens, num_logprobs)
 
     if model in V1_SUPPORTED_MODELS:
+        if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
+            block_size = ATTN_BLOCK_SIZES[model]
+        else:
+            block_size = 16
+
         with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
             if model in HYBRID_MODELS:
@@ -105,7 +118,8 @@ def test_models(
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
                              enforce_eager=True,
-                             enable_prefix_caching=False) as vllm_model:
+                             enable_prefix_caching=False,
+                             block_size=block_size) as vllm_model:
                 vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                     example_prompts, max_tokens, num_logprobs)
     else:
```

vllm/config.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -1553,9 +1553,6 @@ class CacheConfig:
     checkpoint if available. Otherwise, the scales will default to 1.0."""
     cpu_kvcache_space_bytes: Optional[int] = None
     """(CPU backend only) CPU key-value cache space."""
-    mamba_page_size_padded: Optional[int] = None
-    """ Optional override for mamba page size; used by hybrid mamaba/attention
-    models to ensure exact alignment with attention page size."""
 
     # Will be set after profiling.
     num_gpu_blocks: Optional[int] = field(default=None, init=False)
```

vllm/model_executor/models/config.py

Lines changed: 0 additions & 194 deletions
```diff
@@ -1,18 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from copy import deepcopy
-from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
-import vllm.envs as envs
-from vllm.distributed import divide
 from vllm.logger import init_logger
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
-from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
 
 if TYPE_CHECKING:
-    from transformers.configuration_utils import PretrainedConfig
-
     from vllm.config import VllmConfig
 
 logger = init_logger(__name__)
@@ -198,197 +191,10 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         }
 
 
-class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
-
-    @classmethod
-    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int) -> int:
-        """Compute the increase in group numbers to account for
-        replication in order to accompany the head shards."""
-
-        # in the case ngoups % tp_size == 0, this will be zero
-        if ngroups % tp_size == 0:
-            return 0
-
-        # for n_groups == 1, this is exactly tp_size - n_groups
-        return tp_size - ngroups
-
-    @dataclass
-    class MambaConfig:
-        expand: int
-        n_groups: int
-        n_heads: int
-        d_head: int
-        d_state: int
-        d_conv: int
-
-    @classmethod
-    def parse_mamba_config(cls, config: "PretrainedConfig") -> MambaConfig:
-        return cls.MambaConfig(
-            expand=config.mamba_expand,
-            n_groups=config.mamba_n_groups,
-            n_heads=config.mamba_n_heads,
-            d_head=config.mamba_d_head,
-            d_state=config.mamba_d_state,
-            d_conv=config.mamba_d_conv,
-        )
-
-    @classmethod
-    def get_mamba_cache_shape(
-            cls, vllm_config: "VllmConfig"
-    ) -> tuple[tuple[int, int], tuple[int, int]]:
-
-        parallel_config = vllm_config.parallel_config
-        hf_config = vllm_config.model_config.hf_config
-        mamba_config = cls.parse_mamba_config(hf_config)
-
-        world_size = parallel_config.tensor_parallel_size
-        hidden_size = hf_config.hidden_size
-        intermediate_size = mamba_config.expand * hidden_size
-
-        # if n_groups is not divisible by world_size, need to extend the shards
-        # to ensure all groups needed by a head is sharded along with it
-        n_groups = (mamba_config.n_groups + cls.extra_groups_for_head_shards(
-            mamba_config.n_groups, world_size))
-
-        # - heads and n_groups are TP-ed
-        conv_dim = (intermediate_size + 2 * n_groups * mamba_config.d_state)
-        conv_state_shape = (
-            divide(conv_dim, world_size),
-            mamba_config.d_conv - 1,
-        )
-
-        # These are not TP-ed as they depend on A, dt_bias, D
-        # - they are typically small
-        # e.g., (h_heads, d_head, d_state) = (128, 64, 128)
-        temporal_state_shape = (
-            divide(mamba_config.n_heads, world_size),
-            mamba_config.d_head,
-            mamba_config.d_state,
-        )
-
-        return conv_state_shape, temporal_state_shape
-
-    @classmethod
-    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
-        """
-        Ensure that page size of attention layers is greater than or
-        equal to the mamba layers. If not, automatically set the attention
-        block size to ensure that it is. If the attention page size is
-        strictly greater than the mamba page size, we pad the mamba page size
-        to make them equal.
-
-        Args:
-            vllm_config: vLLM Config
-        """
-
-        if not envs.VLLM_USE_V1:
-            return
-
-        cache_config = vllm_config.cache_config
-        model_config = vllm_config.model_config
-        parallel_config = vllm_config.parallel_config
-
-        if cache_config.cache_dtype == "auto":
-            kv_cache_dtype = model_config.dtype
-        else:
-            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
-
-        # get attention page size (for 1 token)
-        attn_page_size_1_token = FullAttentionSpec(
-            block_size=1,
-            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
-            head_size=model_config.get_head_size(),
-            dtype=kv_cache_dtype,
-            use_mla=model_config.use_mla).page_size_bytes
-
-        # get mamba page size
-        mamba_page_size = MambaSpec(
-            shapes=cls.get_mamba_cache_shape(vllm_config),
-            dtype=kv_cache_dtype,
-            block_size=model_config.max_model_len,
-        ).page_size_bytes
-
-        # some attention backends (e.g. FA) only support setting
-        # block size to multiple of 16, so let's suggest a value
-        # that would work (note: FA is currently not compatible
-        # with mamba layers, use FlashInfer instead).
-        attn_block_size = 16 * cdiv(mamba_page_size,
-                                    16 * attn_page_size_1_token)
-
-        # override attention block size if either (a) the
-        # user has not set it or (b) the user has set it
-        # too small.
-        if (cache_config.block_size is None
-                or cache_config.block_size < attn_block_size):
-            cache_config.block_size = attn_block_size
-            logger.info(
-                "Setting attention block size to %d tokens "
-                "to ensure that attention page size is >= mamba page size.",
-                attn_block_size)
-
-        # compute new attention page size
-        attn_page_size = \
-            cache_config.block_size * attn_page_size_1_token
-
-        assert attn_page_size >= mamba_page_size
-
-        if attn_page_size == mamba_page_size:
-            # don't need to pad mamba page size
-            return
-
-        # pad mamba page size to exactly match attention
-        if (cache_config.mamba_page_size_padded is None
-                or cache_config.mamba_page_size_padded != attn_page_size):
-            cache_config.mamba_page_size_padded = (attn_page_size)
-            mamba_padding_pct = 100 * (attn_page_size -
-                                       mamba_page_size) / mamba_page_size
-            logger.info(
-                "Padding mamba page size by %.2f%% to ensure "
-                "that mamba page size and attention page size are "
-                "exactly equal.", mamba_padding_pct)
-
-
-class NemotronHModelConfig(HybridAttentionMambaModelConfig):
-
-    @classmethod
-    def parse_mamba_config(
-            cls, config: "PretrainedConfig"
-    ) -> HybridAttentionMambaModelConfig.MambaConfig:
-        return HybridAttentionMambaModelConfig.MambaConfig(
-            expand=config.expand,
-            n_groups=config.n_groups,
-            n_heads=config.mamba_num_heads,
-            d_head=config.mamba_head_dim,
-            d_state=config.ssm_state_size,
-            d_conv=config.conv_kernel,
-        )
-
-
-class Zamba2ModelConfig(HybridAttentionMambaModelConfig):
-
-    @classmethod
-    def parse_mamba_config(
-            cls, config: "PretrainedConfig"
-    ) -> HybridAttentionMambaModelConfig.MambaConfig:
-        return HybridAttentionMambaModelConfig.MambaConfig(
-            expand=config.mamba_expand,
-            n_groups=config.mamba_ngroups,
-            n_heads=config.n_mamba_heads,
-            d_head=config.mamba_headdim,
-            d_state=config.mamba_d_state,
-            d_conv=config.mamba_d_conv,
-        )
-
-
 MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "GteModel": SnowflakeGteNewModelConfig,
     "GteNewModel": GteNewModelConfig,
     "NomicBertModel": NomicBertModelConfig,
     "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
     "XLMRobertaModel": JinaRobertaModelConfig,
-    "FalconH1ForCausalLM": HybridAttentionMambaModelConfig,
-    "BambaForCausalLM": HybridAttentionMambaModelConfig,
-    "GraniteMoeHybridForCausalLM": HybridAttentionMambaModelConfig,
-    "NemotronHForCausalLM": NemotronHModelConfig,
-    "Zamba2ForCausalLM": Zamba2ModelConfig,
 }
```
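As a reading aid for the `get_mamba_cache_shape` logic removed above, the following standalone sketch reproduces the same state-shape arithmetic with made-up hyperparameters (illustrative assumptions, not values from any particular checkpoint):

```python
# Standalone sketch of the deleted get_mamba_cache_shape arithmetic.
# All hyperparameters below are illustrative assumptions.
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0
    return numerator // denominator

tp_size = 2                      # tensor parallel world size
expand, hidden_size = 2, 4096
n_groups, n_heads = 1, 128
d_head, d_state, d_conv = 64, 128, 4

intermediate_size = expand * hidden_size                            # 8192
# replicate groups so each head shard keeps the groups it needs
extra = 0 if n_groups % tp_size == 0 else tp_size - n_groups        # 1
conv_dim = intermediate_size + 2 * (n_groups + extra) * d_state     # 8704
conv_state_shape = (divide(conv_dim, tp_size), d_conv - 1)          # (4352, 3)
temporal_state_shape = (divide(n_heads, tp_size), d_head, d_state)  # (64, 64, 128)
print(conv_state_shape, temporal_state_shape)
```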

vllm/v1/worker/gpu_model_runner.py

Lines changed: 55 additions & 3 deletions
```diff
@@ -43,7 +43,7 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, async_tensor_h2d,
+                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
                         check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
@@ -2675,8 +2675,9 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                     "Prefix caching is not supported for Mamba yet.")
             max_model_len = self.vllm_config.model_config.max_model_len
 
-            page_size_padded = (
-                self.vllm_config.cache_config.mamba_page_size_padded)
+            page_size_padded = self._maybe_pad_mamba_page_size(
+                attn_layers, mamba_layers, kv_cache_spec, max_model_len,
+                block_size)
 
             # Set block_size to max_model_len, so that mamba model will always
             # have only one block in the KV cache.
@@ -2688,3 +2689,54 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                     page_size_padded=page_size_padded)
 
         return kv_cache_spec
+
+    def _maybe_pad_mamba_page_size(
+        self,
+        attn_layers: dict[str, Attention],
+        mamba_layers: dict[str, MambaMixer2],
+        kv_cache_spec: dict[str, KVCacheSpec],
+        max_model_len: int,
+        block_size: int,
+    ) -> Optional[int]:
+        """
+        Ensure that page size of attention KV cache groups is greater than or
+        equal to the mamba KV cache groups. If not, we suggest to the user
+        how to set the attention block size to ensure that it is.
+
+        If the attention page size is strictly greater than the mamba page size,
+        we pad the mamba page size to make them equal.
+
+        Args:
+            attn_layers: Attention layers
+            mamba_layers: Mamba layers
+            kv_cache_spec: KV cache spec (populated with attention layers)
+
+        Returns:
+            Optional[int]: Mamba page size with padding (None if no padding).
+        """
+
+        if len(attn_layers) == 0:
+            return None
+
+        attn_layer_name = next(iter(attn_layers))
+        attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
+        mamba_layer_name = next(iter(mamba_layers))
+        mamba_page_size = MambaSpec(
+            shapes=mamba_layers[mamba_layer_name].get_state_shape(),
+            dtype=self.kv_cache_dtype,
+            block_size=max_model_len).page_size_bytes
+        if attn_page_size < mamba_page_size:
+            # attention page size (for 16 tokens)
+            attn_page_size_16 = 16 * attn_page_size // block_size
+            # some attention backends (e.g. FA) only support setting
+            # block size to multiple of 16, so let's suggest a value
+            # that would work (note: FA is currently not compatible
+            # with mamba layers, use FlashInfer instead).
+            suggest_attn_block_size = 16 * cdiv(mamba_page_size,
+                                                attn_page_size_16)
+            raise ValueError(
+                "Attention block size should be increased to at least "
+                f"{suggest_attn_block_size} in order to match "
+                "the mamba page size")
+
+        return attn_page_size
```
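To make the block-size suggestion above concrete, here is a small worked example of the `16 * cdiv(...)` rounding used in the error message; the byte counts are invented for illustration and are not measurements from a real model.

```python
# Worked example of the block-size suggestion logic above.
# The byte counts are illustrative assumptions, not real measurements.
def cdiv(a: int, b: int) -> int:
    """Ceiling division, as in vllm.utils.cdiv."""
    return -(-a // b)

mamba_page_size = 270_000        # assumed bytes of mamba state per sequence
attn_page_size_1_token = 512     # assumed bytes of attention KV cache per token

# Round the required block size up to a multiple of 16, since some attention
# backends only accept block sizes that are multiples of 16.
suggest_attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)
print(suggest_attn_block_size)   # 528 -> pass this as block_size to the engine
```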
