diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index e6dd6c35e64d..ecaae3ec1fc4 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -3,6 +3,7 @@ import pytest +from tests.models.registry import HF_EXAMPLE_MODELS from tests.utils import multi_gpu_test from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams @@ -19,31 +20,55 @@ SSM_MODELS = [ "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", - # TODO: Compare to a Mamba2 model. The HF transformers implementation of - # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test - # doesn't compare vLLM output with HF output. - # See https://github.com/huggingface/transformers/pull/35943 "mistralai/Mamba-Codestral-7B-v0.1", ] HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", - # NOTE: Currently the test failes due to HF transformers issue fixed in: - # https://github.com/huggingface/transformers/pull/39033 - # We will enable vLLM test for Granite after next HF transformers release. - # "ibm-granite/granite-4.0-tiny-preview", # NOTE: Running Plamo2 in transformers implementation requires to install # causal-conv1d package, which is not listed as a test dependency as it's # not compatible with pip-compile. "pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", "hmellor/tiny-random-BambaForCausalLM", + "ibm-ai-platform/Bamba-9B-v1", + "nvidia/Nemotron-H-8B-Base-8K", + "ibm-granite/granite-4.0-tiny-preview", + "tiiuae/Falcon-H1-0.5B-Base", +] + +HF_UNSUPPORTED_MODELS = [ + # The HF transformers implementation of + # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test + # doesn't compare vLLM output with HF output. + # See https://github.com/huggingface/transformers/pull/35943 + "mistralai/Mamba-Codestral-7B-v0.1", + # Note: I'm not seeing the same output from vLLM V0 vs. 
HF transformers + # for Nemotron-H-8B; currently only compare vLLM V0 vs. vLLM V1 + "nvidia/Nemotron-H-8B-Base-8K", + # NOTE: Currently the test fails due to HF transformers issue fixed in: + # https://github.com/huggingface/transformers/pull/39033 + # We will enable vLLM test for Granite after next HF transformers release. + "ibm-granite/granite-4.0-tiny-preview", ] V1_SUPPORTED_MODELS = [ "mistralai/Mamba-Codestral-7B-v0.1", + "ibm-ai-platform/Bamba-9B-v1", + "Zyphra/Zamba2-1.2B-instruct", + "nvidia/Nemotron-H-8B-Base-8K", + "ibm-granite/granite-4.0-tiny-preview", + "tiiuae/Falcon-H1-0.5B-Base", ] +ATTN_BLOCK_SIZES = { + "ibm-ai-platform/Bamba-9B-v1": 528, + "Zyphra/Zamba2-1.2B-instruct": 80, + "nvidia/Nemotron-H-8B-Base-8K": 528, + "ibm-granite/granite-4.0-tiny-preview": 400, + "tiiuae/Falcon-H1-0.5B-Base": 800, +} + # Avoid OOM MAX_NUM_SEQS = 4 @@ -60,8 +85,16 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + with hf_runner(model) as hf_model: - if model != "mistralai/Mamba-Codestral-7B-v0.1": + if model not in HF_UNSUPPORTED_MODELS: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) else: @@ -72,12 +105,21 @@ def test_models( example_prompts, max_tokens, num_logprobs) if model in V1_SUPPORTED_MODELS: + if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES: + block_size = ATTN_BLOCK_SIZES[model] + else: + block_size = 16 + with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + if model in HYBRID_MODELS: + # required due to reorder_batch behaviour + m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, enforce_eager=True, - enable_prefix_caching=False) as vllm_model: + enable_prefix_caching=False, + block_size=block_size) as vllm_model: 
vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) else: @@ -111,6 +153,14 @@ def test_batching( max_tokens: int, num_logprobs: int, ) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + for_loop_outputs = [] with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: for prompt in example_prompts: diff --git a/tests/models/registry.py b/tests/models/registry.py index 704aa76b84d4..728c18643a00 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -169,7 +169,7 @@ def check_available_online( "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501 "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), - "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct", + "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base", min_transformers_version="4.53"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index d640d7dc49d1..7a7ba346a719 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -13,7 +13,6 @@ "openai/whisper-large-v3", # transcription "facebook/bart-large-cnn", # encoder decoder "state-spaces/mamba-130m-hf", # mamba1 - "hmellor/tiny-random-BambaForCausalLM", # hybrid "BAAI/bge-m3", # embedding ] diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index ccfb278cdff6..3f67fc35afdf 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -108,7 +108,7 @@ def _selective_scan_update_kernel( # is the same as 
the batch id. if HAS_STATE_BATCH_INDICES: state_batch_indices_ptr += pid_b - state_batch_idx = tl.load(state_batch_indices_ptr) + state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) state_ptr += (state_batch_idx * stride_state_batch + pid_h * stride_state_head) else: diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 29e0e2a2edb1..d743c52074c6 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -9,6 +9,7 @@ from torch import nn from transformers import BambaConfig +from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -36,7 +37,7 @@ from vllm.utils import LayerBlockType from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsQuant, SupportsV0Only) + SupportsQuant) from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -97,7 +98,9 @@ def __init__(self, head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.mixer", + chunk_size=config.mamba_chunk_size) self.feed_forward = BambaMLP(config, quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, @@ -313,10 +316,14 @@ def forward( attn_metadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -337,7 +344,8 @@ def forward( num_attn += 1 
layer_mamba_cache_params = None - if isinstance(layer, BambaMixerDecoderLayer): + if isinstance(layer, + BambaMixerDecoderLayer) and mamba_cache_params: layer_mamba_cache_params = mamba_cache_params.at_layer_idx( i - num_attn) @@ -411,7 +419,7 @@ def load_weights(self, weights: Iterable[tuple[str, class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsV0Only, SupportsQuant): + IsHybrid, SupportsQuant): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -475,15 +483,22 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_mamba_layers = \ + self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, + LayerBlockType.mamba + ) + + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, + num_mamba_layers, *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 28f257eabed0..a76e1f256e04 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -8,6 +8,7 @@ from torch import nn from transformers import FalconH1Config +from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from 
vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -33,8 +34,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsV0Only) +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -85,6 +85,7 @@ def __init__( config: FalconH1Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -107,6 +108,8 @@ def __init__( activation=config.hidden_act, quant_config=quant_config, use_rms_norm=config.mamba_rms_norm, + prefix=f"{prefix}.mixer", + chunk_size=config.mamba_chunk_size, ) # n_groups is overridden later by `MambaMixer2` self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state @@ -316,6 +319,7 @@ def __init__( prefix: str = "", ) -> None: super().__init__() + # Instantiate the attention branch self.self_attn = FalconH1AttentionDecoderLayer( config=config, @@ -323,11 +327,18 @@ def __init__( quant_config=quant_config, prefix=prefix, ) + + # In V1 all attention/ssm layers must have + # different index in prefix + ssm_layer_idx = config.num_hidden_layers + layer_idx + ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}" + # Instantiate the SSM branch self.mamba = FalconH1SSMDecoderLayer( config=config, cache_config=cache_config, quant_config=quant_config, + prefix=ssm_prefix, ) self.ssm_out_multiplier = config.ssm_out_multiplier self.ssm_in_multiplier = config.ssm_in_multiplier @@ -452,10 +463,16 @@ def forward( # proper continuous batching computation including # chunked prefill attn_metadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - 
attn_metadata=attn_metadata, - ) + + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None + if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds * self.embedding_multiplier @@ -468,7 +485,9 @@ def forward( for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i) + layer_mamba_cache_params = None + if mamba_cache_params: + layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i) hidden_states = layer( positions=positions, hidden_states=hidden_states, @@ -484,7 +503,7 @@ def forward( class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsV0Only): + IsHybrid): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -558,15 +577,19 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ): - if self.mamba_cache is None: - self.mamba_cache = MambaCacheManager( - self.vllm_config, - self.lm_head.weight.dtype - if hasattr(self.lm_head, 'weight') else torch.bfloat16, - self.config.num_hidden_layers, - *self._get_mamba_cache_shape(), - ) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + self.mamba_cache = MambaCacheManager( + self.vllm_config, + self.lm_head.weight.dtype if hasattr( + self.lm_head, 'weight') else torch.bfloat16, + self.config.num_hidden_layers, + *self._get_mamba_cache_shape(), + ) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + hidden_states = self.model( input_ids, positions, diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 
33e8626209d5..676ef24fc4da 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -9,6 +9,7 @@ from torch import nn from transformers import GraniteMoeHybridConfig +from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -36,7 +37,7 @@ from .granitemoe import GraniteMoeMoE from .granitemoeshared import GraniteMoeSharedMLP from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsQuant, SupportsV0Only) + SupportsQuant) from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -67,7 +68,9 @@ def __init__(self, head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.mixer", + chunk_size=config.mamba_chunk_size) self.block_sparse_moe = None if getattr(config, "num_local_experts", 0) > 0: @@ -361,10 +364,15 @@ def forward( ) -> torch.Tensor: attn_metadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) + + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -386,7 +394,9 @@ def forward( num_attn += 1 layer_mamba_cache_params = None - if isinstance(layer, GraniteMoeHybridMambaDecoderLayer): + if isinstance( + layer, + GraniteMoeHybridMambaDecoderLayer) and mamba_cache_params: layer_mamba_cache_params = mamba_cache_params.at_layer_idx( i - num_attn) @@ -501,8 +511,7 @@ def _load_expert(n, p, name, shard_id, 
expert_id): class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, - SupportsPP, IsHybrid, SupportsV0Only, - SupportsQuant): + SupportsPP, IsHybrid, SupportsQuant): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -571,14 +580,20 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.model_config.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_mamba_layers = ( + self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, + LayerBlockType.mamba)) + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.model_config.dtype, + num_mamba_layers, *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 3424efa80d48..5d51b01df9db 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -23,6 +23,7 @@ import torch from torch import nn +from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -44,8 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsQuant, - 
SupportsV0Only) + SupportsQuant) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.models.utils import ( @@ -153,6 +153,8 @@ def __init__( rms_norm_eps=config.rms_norm_eps, activation=config.mamba_hidden_act, quant_config=quant_config, + prefix=f"{prefix}.mixer", + chunk_size=config.chunk_size, ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -348,10 +350,14 @@ def forward( attn_metadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -369,7 +375,8 @@ def forward( for i in range(len(self.layers)): layer = self.layers[i] layer_mamba_cache_params = None - if isinstance(layer, NemotronHMambaDecoderLayer): + if isinstance(layer, + NemotronHMambaDecoderLayer) and mamba_cache_params: layer_mamba_cache_params = mamba_cache_params.at_layer_idx( i - num_non_mamba_layers) else: @@ -437,7 +444,7 @@ def load_weights(self, weights: Iterable[tuple[str, class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsV0Only, SupportsQuant): + IsHybrid, SupportsQuant): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -499,15 +506,23 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + + num_mamba_layers = \ + 
self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, + LayerBlockType.mamba + ) + + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, + num_mamba_layers, *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index a4f97c774f70..54c80cfa5922 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -15,6 +15,7 @@ from torch import nn from transformers import Zamba2Config +from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -41,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import HasInnerState, IsHybrid, SupportsV0Only +from .interfaces import HasInnerState, IsHybrid from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix @@ -58,6 +59,7 @@ def __init__( rank: int, output_dim: Union[int, list[int]], quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): """Initialize the attention layer. @@ -283,6 +285,7 @@ def __init__( bare_block_idx: int, num_hybrid_layers: dict[int, int], quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: """Initialize the MLP layer. @@ -471,11 +474,10 @@ class Zamba2MambaDecoderLayer(nn.Module): computation depending on configuration. 
""" - def __init__( - self, - config: Zamba2Config, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: + def __init__(self, + config: Zamba2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: """Initialize the Mamba decoder layer. Args: @@ -486,20 +488,21 @@ def __init__( # Initialize Mamba mixer with expanded intermediate size intermediate_size = config.mamba_expand * config.hidden_size - self.mamba = MambaMixer2( - hidden_size=config.hidden_size, - ssm_state_size=config.mamba_d_state, - conv_kernel_size=config.mamba_d_conv, - intermediate_size=intermediate_size, - use_conv_bias=config.use_conv_bias, - use_bias=config.add_bias_linear, - n_groups=config.mamba_ngroups, - num_heads=config.n_mamba_heads, - head_dim=intermediate_size // config.n_mamba_heads, - rms_norm_eps=config.rms_norm_eps, - activation="silu", - quant_config=quant_config, - ) + self.mamba = MambaMixer2(hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=intermediate_size, + use_conv_bias=config.use_conv_bias, + use_bias=config.add_bias_linear, + n_groups=config.mamba_ngroups, + num_heads=config.n_mamba_heads, + head_dim=intermediate_size // + config.n_mamba_heads, + rms_norm_eps=config.rms_norm_eps, + activation="silu", + quant_config=quant_config, + prefix=f"{prefix}.mixer", + chunk_size=config.chunk_size) # Input normalization self.input_layernorm = RMSNorm(config.hidden_size, @@ -573,6 +576,7 @@ def __init__( config: Zamba2Config, block_idx: int, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: """Initialize the hybrid layer. 
@@ -589,7 +593,8 @@ def __init__( bias=False, quant_config=quant_config) self.mamba_decoder = Zamba2MambaDecoderLayer(config, - quant_config=quant_config) + quant_config=quant_config, + prefix=prefix) def forward( self, @@ -699,14 +704,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # Initialize layers according to block type configuration layers = [] for layer_idx, layer_type in enumerate(config.layers_block_type): + # tdoublep: avoid layers getting same index + # somewhat hacky but correct (I think) + prefix = str(len(layer2block_map) + layer_idx) if layer_type == "hybrid": block = next(blocks) block_idx = layer2block_map[layer_idx] layers.append( - Zamba2HybridLayer(block, config, block_idx, quant_config)) + Zamba2HybridLayer(block, + config, + block_idx, + quant_config, + prefix=prefix)) else: layers.append( - Zamba2MambaDecoderLayer(config, quant_config=quant_config)) + Zamba2MambaDecoderLayer(config, + quant_config=quant_config, + prefix=prefix)) self.layers = nn.ModuleList(layers) # Final layer normalization @@ -751,19 +765,30 @@ def forward( attn_metadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None # Process through layers original_hidden_states = torch.clone(hidden_states) for layer_idx, layer in enumerate(self.layers): + + layer_mamba_cache_params = None + if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer)) + and mamba_cache_params): + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + layer_idx) + layer_outputs = layer( hidden_states, original_hidden_states=original_hidden_states, positions=positions, - mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx), + 
mamba_cache_params=layer_mamba_cache_params, mamba2_metadata=mamba2_metadata, ) hidden_states = layer_outputs @@ -803,7 +828,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): +class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): """Zamba2 model with causal language modeling head. This class wraps the core Zamba2 model and adds: @@ -897,14 +922,16 @@ def forward(self, Output hidden states """ # Initialize Mamba cache if needed - if self.mamba_cache is None: - num_mamba_layers = self.config.num_hidden_layers - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) - - # Get cache parameters for current run - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_mamba_layers = self.config.num_hidden_layers + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, + num_mamba_layers, *self._get_mamba_cache_shape()) + + # Get cache parameters for current run + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) # Forward pass through model hidden_states = self.model( diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index b88a5990ca92..38de00625e3f 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -27,6 +27,7 @@ def __init__( ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len + self.enable_caching = enable_caching self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events) @@ -268,9 +269,13 @@ def verify_and_split_kv_cache_groups(self) -> None: self.full_attention_block_size = self.full_attention_spec.block_size self.other_block_size = self.other_spec.block_size - assert 
self.other_block_size % self.full_attention_block_size == 0, ( - "KVCacheCoordinator assumes the block_size of full attention " - "layers is divisible by other layers now.") + + if self.enable_caching: + # this requirement is only needed for the prefix caching logic + divisible = self.other_block_size % self.full_attention_block_size + assert divisible == 0, ( + "KVCacheCoordinator assumes the block_size of full " + "attention layers is divisible by other layers now.") if max(self.full_attention_group_ids) < min(self.other_group_ids): self.full_attn_first = True diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 08bb0efb2f3d..6937455e7d85 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -84,12 +84,15 @@ def __init__( self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - assert len( - set(g.kv_cache_spec.block_size - for g in kv_cache_config.kv_cache_groups) - ) == 1, "Only one block size is supported for now" - self.block_size = kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.block_size + + self.block_size: Optional[int] = None + if self.enable_caching: + assert len( + set(g.kv_cache_spec.block_size + for g in kv_cache_config.kv_cache_groups) + ) == 1, "Only one block size is supported for now" + self.block_size = kv_cache_config.kv_cache_groups[ + 0].kv_cache_spec.block_size self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, @@ -154,6 +157,7 @@ def get_computed_blocks(self, # if the scheduler has tried to schedule the request before. 
block_hashes = self.req_to_block_hashes[request.request_id] if not block_hashes: + assert self.block_size is not None block_hashes = hash_request_tokens(self.caching_hash_fn, self.block_size, request) self.req_to_block_hashes[request.request_id] = block_hashes diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 9489bcf433fd..2fbcb569e3d5 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -864,9 +864,11 @@ def _get_kv_cache_config_uniform_page_size( kv_cache_groups=kv_cache_groups, ) + min_block_size = min( + [group.kv_cache_spec.block_size for group in kv_cache_groups]) + # Print the KV cache size and maximum concurrency. - num_tokens = num_blocks // len( - grouped_layers) * vllm_config.cache_config.block_size + num_tokens = num_blocks // len(grouped_layers) * min_block_size num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index c48775adc9b8..43456a987def 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -159,6 +159,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: class MambaSpec(KVCacheSpec): shapes: tuple[tuple[int, ...], ...] 
dtype: torch.dtype + page_size_padded: Optional[int] = None def __post_init__(self): self.num_elements = sum(prod(shape) for shape in self.shapes) @@ -169,7 +170,11 @@ def type_id(self) -> str: @property def page_size_bytes(self) -> int: - return self.num_elements * get_dtype_size(self.dtype) + page_size = self.num_elements * get_dtype_size(self.dtype) + if self.page_size_padded is not None: + assert self.page_size_padded >= page_size + return self.page_size_padded + return page_size def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: # We allocate 1 block for each request now, so max_memory_usage_bytes is diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4786d047acb5..57d0c7b50ff5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -334,6 +334,9 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: # the same order of requests. We ensure this by only allowing the first # group to reorder the batch and asserting that all other groups do not # reorder the batch. + # TODO(tdoublep): make this more flexible so that any group can + # re-order the batch (not only the first). + # TODO(tdoublep): verify this during engine init instead of at runtime for i in range(1, len(self.kv_cache_config.kv_cache_groups)): batch_reordered = self.attn_metadata_builders[i].reorder_batch( self.input_batch, scheduler_output) @@ -2449,6 +2452,7 @@ def _reshape_kv_cache_tensors( corresponding memory buffer for KV cache. 
def _verify_hybrid_attention_mamba_layout(
        self, kv_cache_config: KVCacheConfig,
        kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None:
    """
    Verify that the KV cache memory layout is compatible for
    models with both attention and mamba KV cache groups.

    For every layer of every attention KV cache group, the shape reported
    by that group's attention backend must begin with ``(num_blocks, 2)``.
    Groups whose spec is not an ``AttentionSpec`` are not inspected.

    Args:
        kv_cache_config: The KV cache config
        kv_cache_raw_tensors: The KV cache buffer of each layer.

    Raises:
        ValueError: If an attention backend reports a KV cache shape whose
            first dimension is not ``num_blocks`` or whose second
            dimension is not 2.
    """
    for i, kv_cache_group_spec in enumerate(
            kv_cache_config.kv_cache_groups):
        kv_cache_spec = kv_cache_group_spec.kv_cache_spec
        # The spec is shared by every layer of the group, so decide once
        # per group (not once per layer) whether there is anything to check.
        if not isinstance(kv_cache_spec, AttentionSpec):
            continue

        page_size_bytes = kv_cache_spec.page_size_bytes
        for layer_name in kv_cache_group_spec.layer_names:
            raw_tensor = kv_cache_raw_tensors[layer_name]
            num_blocks = raw_tensor.numel() // page_size_bytes
            kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
                num_blocks, kv_cache_spec.block_size,
                kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
            if kv_cache_shape[0] != num_blocks or kv_cache_shape[1] != 2:
                raise ValueError(
                    "Hybrid models in V1 require an attention "
                    "backend with kv_cache_shape="
                    "(num_blocks, 2, ...). Please try setting "
                    "VLLM_ATTENTION_BACKEND=FLASHINFER")
def _maybe_pad_mamba_page_size(
    self,
    attn_layers: dict[str, Attention],
    mamba_layers: dict[str, MambaMixer2],
    kv_cache_spec: dict[str, KVCacheSpec],
    max_model_len: int,
    block_size: int,
) -> Optional[int]:
    """
    Ensure that page size of attention KV cache groups is greater than or
    equal to the mamba KV cache groups. If not, we suggest to the user
    how to set the attention block size to ensure that it is.

    If the attention page size is strictly greater than the mamba page size,
    we pad the mamba page size to make them equal.

    Only the first attention layer and the first mamba layer are inspected.

    Args:
        attn_layers: Attention layers
        mamba_layers: Mamba layers
        kv_cache_spec: KV cache spec (populated with attention layers)
        max_model_len: Maximum model length; used as the block size of the
            mamba spec whose page size is measured.
        block_size: Attention block size; used to scale the attention page
            size down to 16 tokens when suggesting a new block size.

    Returns:
        Optional[int]: The attention page size in bytes, to be used as the
            (padded) mamba page size. None if there are no attention
            layers or no mamba layers, i.e. nothing needs to be aligned.

    Raises:
        ValueError: If the attention page size is smaller than the mamba
            page size. The message suggests a minimum attention block size.
    """
    # Padding only makes sense when both kinds of layers exist. Checking
    # mamba_layers too avoids a bare StopIteration from next() below.
    if not attn_layers or not mamba_layers:
        return None

    attn_layer_name = next(iter(attn_layers))
    attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
    mamba_layer_name = next(iter(mamba_layers))
    # Build a throwaway, unpadded spec just to measure the mamba page size.
    mamba_page_size = MambaSpec(
        shapes=mamba_layers[mamba_layer_name].get_state_shape(),
        dtype=self.kv_cache_dtype,
        block_size=max_model_len).page_size_bytes
    if attn_page_size < mamba_page_size:
        # attention page size (for 16 tokens)
        attn_page_size_16 = 16 * attn_page_size // block_size
        # some attention backends (e.g. FA) only support setting
        # block size to multiple of 16, so let's suggest a value
        # that would work (note: FA is currently not compatible
        # with mamba layers, use FlashInfer instead).
        # NOTE(review): cdiv is presumably ceiling division - confirm.
        suggest_attn_block_size = 16 * cdiv(mamba_page_size,
                                            attn_page_size_16)
        raise ValueError(
            "Attention block size should be increased to at least "
            f"{suggest_attn_block_size} in order to match "
            "the mamba page size")

    return attn_page_size