[V1] [Doc] Update V1 docs for Mamba models #20499
@@ -1,11 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING

import vllm.envs as envs
from vllm.distributed import divide
from vllm.logger import init_logger
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec

if TYPE_CHECKING:
    from transformers.configuration_utils import PretrainedConfig

    from vllm.config import VllmConfig

logger = init_logger(__name__)
@@ -191,10 +198,197 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        }


class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):

    @classmethod
    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int) -> int:
        """Compute the increase in group numbers to account for
        replication in order to accompany the head shards."""

        # in the case ngroups % tp_size == 0, this will be zero
        if ngroups % tp_size == 0:
            return 0

        # for n_groups == 1, this is exactly tp_size - n_groups
        return tp_size - ngroups

    @dataclass
    class MambaConfig:
        expand: int
        n_groups: int
        n_heads: int
        d_head: int
        d_state: int
        d_conv: int

    @classmethod
    def parse_mamba_config(cls, config: "PretrainedConfig") -> MambaConfig:
        return cls.MambaConfig(
            expand=config.mamba_expand,
            n_groups=config.mamba_n_groups,
            n_heads=config.mamba_n_heads,
            d_head=config.mamba_d_head,
            d_state=config.mamba_d_state,
            d_conv=config.mamba_d_conv,
        )

    @classmethod
    def get_mamba_cache_shape(
            cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, int], tuple[int, int]]:

        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        mamba_config = cls.parse_mamba_config(hf_config)

        world_size = parallel_config.tensor_parallel_size
        hidden_size = hf_config.hidden_size
        intermediate_size = mamba_config.expand * hidden_size

        # if n_groups is not divisible by world_size, need to extend the shards
        # to ensure all groups needed by a head are sharded along with it
        n_groups = (mamba_config.n_groups + cls.extra_groups_for_head_shards(
            mamba_config.n_groups, world_size))

        # - heads and n_groups are TP-ed
        conv_dim = (intermediate_size + 2 * n_groups * mamba_config.d_state)
        conv_state_shape = (
            divide(conv_dim, world_size),
            mamba_config.d_conv - 1,
        )

        # These are not TP-ed as they depend on A, dt_bias, D
        # - they are typically small
        #   e.g., (n_heads, d_head, d_state) = (128, 64, 128)
        temporal_state_shape = (
            divide(mamba_config.n_heads, world_size),
            mamba_config.d_head,
            mamba_config.d_state,
        )

        return conv_state_shape, temporal_state_shape

    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Ensure that the page size of the attention layers is greater than
        or equal to that of the mamba layers. If not, automatically set the
        attention block size to ensure that it is. If the attention page
        size is then strictly greater than the mamba page size, pad the
        mamba page size to make them equal.

        Args:
            vllm_config: vLLM Config
        """

        if not envs.VLLM_USE_V1:
            return

        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config

        if cache_config.cache_dtype == "auto":
            kv_cache_dtype = model_config.dtype
        else:
            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # get attention page size (for 1 token)
        attn_page_size_1_token = FullAttentionSpec(
            block_size=1,
            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
            head_size=model_config.get_head_size(),
            dtype=kv_cache_dtype,
            use_mla=model_config.use_mla).page_size_bytes

        # get mamba page size
        mamba_page_size = MambaSpec(
            shapes=cls.get_mamba_cache_shape(vllm_config),
            dtype=kv_cache_dtype,
            block_size=model_config.max_model_len,
        ).page_size_bytes

> **Review comment** (on the `MambaSpec` construction above): As the logic of getting MambaSpec here is different from that in gpu_model_runner, can you add some check in gpu_model_runner.get_kv_cache_spec to verify that these two are the same.
>
> **Reply:** sure

        # some attention backends (e.g. FA) only support setting
        # block size to a multiple of 16, so let's suggest a value
        # that would work (note: FA is currently not compatible
        # with mamba layers, use FlashInfer instead).
        attn_block_size = 16 * cdiv(mamba_page_size,
                                    16 * attn_page_size_1_token)

        # override attention block size if either (a) the
        # user has not set it or (b) the user has set it
        # too small.
        if (cache_config.block_size is None
                or cache_config.block_size < attn_block_size):
            cache_config.block_size = attn_block_size
            logger.info(
                "Setting attention block size to %d tokens "
                "to ensure that attention page size is >= mamba page size.",
                attn_block_size)

        # compute new attention page size
        attn_page_size = \
            cache_config.block_size * attn_page_size_1_token

        assert attn_page_size >= mamba_page_size

        if attn_page_size == mamba_page_size:
            # don't need to pad mamba page size
            return

        # pad mamba page size to exactly match attention
        if (cache_config.mamba_page_size_padded is None
                or cache_config.mamba_page_size_padded != attn_page_size):
            cache_config.mamba_page_size_padded = attn_page_size
            mamba_padding_pct = 100 * (attn_page_size -
                                       mamba_page_size) / mamba_page_size
            logger.info(
                "Padding mamba page size by %.2f%% to ensure "
                "that mamba page size and attention page size are "
                "exactly equal.", mamba_padding_pct)

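To make the arithmetic above concrete, here is a small standalone sketch of the page-size alignment. All numbers are made up for illustration (a hypothetical hybrid config with `expand=2`, `hidden_size=4096`, `n_groups=1`, `d_conv=4`, `(n_heads, d_head, d_state) = (128, 64, 128)` at TP=1, and 2-byte elements); the byte counts deliberately ignore any extra bookkeeping that `FullAttentionSpec` or `MambaSpec` may add, and `cdiv` is re-defined locally so the snippet runs on its own:

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division (mirrors vllm.utils.cdiv)."""
    return -(a // -b)


# --- hypothetical model/config values (illustrative only) ---
bytes_per_elem = 2                   # bf16 / fp16
tp_size = 1
hidden_size, expand = 4096, 2
n_groups, d_state, d_conv = 1, 128, 4
n_heads, d_head = 128, 64
num_kv_heads, head_size = 8, 128     # attention layers

# mamba state per request: one mamba "page" holds the whole request state
intermediate_size = expand * hidden_size
conv_dim = intermediate_size + 2 * n_groups * d_state
conv_state_elems = (conv_dim // tp_size) * (d_conv - 1)
temporal_state_elems = (n_heads // tp_size) * d_head * d_state
mamba_page_size = (conv_state_elems + temporal_state_elems) * bytes_per_elem

# attention KV cache per token (K and V)
attn_page_size_1_token = 2 * num_kv_heads * head_size * bytes_per_elem

# smallest block size that is a multiple of 16 and makes the attention
# page at least as large as the mamba page
attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)
attn_page_size = attn_block_size * attn_page_size_1_token
assert attn_page_size >= mamba_page_size

# pad the mamba page up to the attention page size
padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size

print(mamba_page_size)         # 2147840
print(attn_page_size_1_token)  # 4096
print(attn_block_size)         # 528
print(attn_page_size)          # 2162688
print(f"{padding_pct:.2f}%")   # 0.69%
```

With these assumed sizes the attention block size is bumped to 528 tokens and the mamba page is padded by well under 1%, which is the intended trade-off: a small amount of padding on the mamba side rather than a mismatch between the two page sizes.
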

class NemotronHModelConfig(HybridAttentionMambaModelConfig):

    @classmethod
    def parse_mamba_config(
            cls, config: "PretrainedConfig"
    ) -> HybridAttentionMambaModelConfig.MambaConfig:
        return HybridAttentionMambaModelConfig.MambaConfig(
            expand=config.expand,
            n_groups=config.n_groups,
            n_heads=config.mamba_num_heads,
            d_head=config.mamba_head_dim,
            d_state=config.ssm_state_size,
            d_conv=config.conv_kernel,
        )


class Zamba2ModelConfig(HybridAttentionMambaModelConfig):

    @classmethod
    def parse_mamba_config(
            cls, config: "PretrainedConfig"
    ) -> HybridAttentionMambaModelConfig.MambaConfig:
        return HybridAttentionMambaModelConfig.MambaConfig(
            expand=config.mamba_expand,
            n_groups=config.mamba_ngroups,
            n_heads=config.n_mamba_heads,
            d_head=config.mamba_headdim,
            d_state=config.mamba_d_state,
            d_conv=config.mamba_d_conv,
        )


MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewModel": GteNewModelConfig,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    "XLMRobertaModel": JinaRobertaModelConfig,
    "FalconH1ForCausalLM": HybridAttentionMambaModelConfig,
    "BambaForCausalLM": HybridAttentionMambaModelConfig,
    "GraniteMoeHybridForCausalLM": HybridAttentionMambaModelConfig,
    "NemotronHForCausalLM": NemotronHModelConfig,
    "Zamba2ForCausalLM": Zamba2ModelConfig,
}

> **Review comment** (on `MODELS_CONFIG_MAP`): I don't want to require each new model to update this page. What about letting get_mamba_cache_shape be an interface that should be implemented by each hybrid model? The abstractions you made now can be some useful utility functions to minimize code duplication when implementing this interface for each model.
>
>     class IsHybrid(Protocol):
>         def get_mamba_cache_shape(cls, ...)
>
> **Reply:** Yeah, was thinking of something similar. I will have a go at it and open a new PR.

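For reference, a minimal sketch of what such a per-model interface could look like, assuming a `Protocol` named `IsHybrid` as in the comment above (the name and signature are illustrative, not an existing vLLM API):

```python
from typing import TYPE_CHECKING, Protocol

if TYPE_CHECKING:
    from vllm.config import VllmConfig


class IsHybrid(Protocol):
    """Hypothetical interface for hybrid attention/mamba models."""

    @classmethod
    def get_mamba_cache_shape(
        cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Return (conv_state_shape, temporal_state_shape) for one request."""
        ...
```

Model classes implementing such a protocol could delegate to the `HybridAttentionMambaModelConfig` helpers above, rather than being registered one by one in `MODELS_CONFIG_MAP`.
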
Review discussion:

> Can you make the checkmark icon consistent with the other entries in this table? Same for the other tables

> The checkmark icon looks identical to me. Could you paste a screenshot of the difference you are seeing?

> [screenshot]

> It's also visible in the git diff

> lol, wth. this is what git diff (and the table) shows for me:
>
> [image]

> Hmm let me just add suggestions in GitHub and you can commit the changes

> must be some unicode weirdness

> done, thanks for that

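Since the root cause here was apparently an invisible difference between look-alike checkmark characters, a quick way to see what a table cell actually contains is to print its codepoints (the sample string below is only an illustration, not the actual table content):

```python
import unicodedata

# two checkmarks that render almost identically but differ in codepoints:
# the second one is followed by U+FE0F (VARIATION SELECTOR-16)
cell = "\u2705 \u2705\ufe0f"

for ch in cell:
    print(f"U+{ord(ch):04X} {unicodedata.name(ch, '<unnamed>')}")
```
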