From de4e3a203014f149da0f4421f7172d719d2bbfec Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 23 Jun 2025 19:14:46 +0000 Subject: [PATCH 01/28] working change Signed-off-by: Thomas Parnell --- vllm/model_executor/models/bamba.py | 41 +++++++++++++------- vllm/v1/core/kv_cache_coordinator.py | 13 +++++-- vllm/v1/core/kv_cache_manager.py | 10 ++--- vllm/v1/core/kv_cache_utils.py | 20 +++++++++- vllm/v1/core/single_type_kv_cache_manager.py | 4 ++ vllm/v1/kv_cache_interface.py | 7 +++- vllm/v1/worker/gpu_model_runner.py | 6 ++- 7 files changed, 76 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 29e0e2a2edb1..26f537616ed2 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -9,6 +9,7 @@ from torch import nn from transformers import BambaConfig +from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -97,7 +98,9 @@ def __init__(self, head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.mixer", + chunk_size=config.mamba_chunk_size) self.feed_forward = BambaMLP(config, quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, @@ -311,12 +314,18 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: + print("self.config.mamba_chunk_size: ", self.config.mamba_chunk_size) + attn_metadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -337,7 +346,7 @@ def forward( num_attn += 1 layer_mamba_cache_params = None - if isinstance(layer, BambaMixerDecoderLayer): + if isinstance(layer, BambaMixerDecoderLayer) and mamba_cache_params: layer_mamba_cache_params = mamba_cache_params.at_layer_idx( i - num_attn) @@ -411,7 +420,7 @@ def load_weights(self, weights: Iterable[tuple[str, class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsV0Only, SupportsQuant): + IsHybrid, SupportsQuant): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -475,15 +484,19 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, + *self._get_mamba_cache_shape()) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + else: + mamba_cache_params = None - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) - mamba_cache_params = 
self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 5620d9bee7a3..fad1fc132d9c 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -41,6 +41,8 @@ def __init__( ) for i, kv_cache_group in enumerate( self.kv_cache_config.kv_cache_groups)) + print("single_type_managers: ", self.single_type_managers) + def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int: @@ -57,10 +59,14 @@ def get_num_blocks_to_allocate( Returns: The number of blocks. """ + print("request_id: ", request_id) + print("num_tokens: ", num_tokens) + print("new_computed_blocks: ", new_computed_blocks) num_blocks_to_allocate = 0 for i, manager in enumerate(self.single_type_managers): num_blocks_to_allocate += manager.get_num_blocks_to_allocate( request_id, num_tokens, new_computed_blocks[i]) + print("i: ", i, " manager: ", manager, " num_blocks_to_allocate: ", num_blocks_to_allocate) return num_blocks_to_allocate def save_new_computed_blocks( @@ -267,9 +273,10 @@ def verify_and_split_kv_cache_groups(self) -> None: self.full_attention_block_size = self.full_attention_spec.block_size self.other_block_size = self.other_spec.block_size - assert self.other_block_size % self.full_attention_block_size == 0, ( - "KVCacheCoordinator assumes the block_size of full attention " - "layers is divisible by other layers now.") + # think this is only needed for prefix caching + #assert self.other_block_size % self.full_attention_block_size == 0, ( + # "KVCacheCoordinator assumes the block_size of full attention " + # "layers is divisible by other layers now.") if max(self.full_attention_group_ids) < min(self.other_group_ids): self.full_attn_first = True diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 08bb0efb2f3d..94ee82b926ca 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -84,13 +84,13 @@ def __init__( self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - assert len( - set(g.kv_cache_spec.block_size - for g in kv_cache_config.kv_cache_groups) - ) == 1, "Only one block size is supported for now" + #assert len( + # set(g.kv_cache_spec.block_size + # for g in kv_cache_config.kv_cache_groups) + #) == 1, "Only one block size is supported for now" self.block_size = kv_cache_config.kv_cache_groups[ 0].kv_cache_spec.block_size - + print("self.block_size: ", self.block_size) self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 9489bcf433fd..d6c9de6e55d1 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -605,6 +605,7 @@ def create_kv_cache_group_specs( merged_layer_spec = layer_specs[0].merge(layer_specs) kv_cache_groups.append( KVCacheGroupSpec(layer_names_one_group, merged_layer_spec)) + print("len(kv_cache_groups): ", len(kv_cache_groups)) return kv_cache_groups @@ -692,7 +693,6 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, page_size = get_uniform_page_size(kv_cache_spec) num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec), available_memory, page_size) - 
per_layer_size = page_size * num_blocks # All layers have the same KV cache spec, so we create one kv cache group # for all layers. @@ -734,6 +734,8 @@ def is_kv_cache_page_size_uniform( """ page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + print("page_sizes: ", page_sizes) + return len(page_sizes) == 1 @@ -811,6 +813,10 @@ def _get_kv_cache_config_uniform_page_size( for layer_name, layer_spec in kv_cache_spec.items(): same_type_layers[layer_spec.type_id].append(layer_name) + + for layer_type in same_type_layers.keys(): + print(layer_type, len(same_type_layers[layer_type])) + # Split each group into smaller groups, to make the number of layers in each # group identical. Add padding to the last group of each type if necessary. # E.g., (full.0, full.1), (sw.0, sw.1, sw.2) @@ -824,6 +830,9 @@ def _get_kv_cache_config_uniform_page_size( # strategy if we want to support more complex patterns (e.g., 20 full + 30 # sw, where the group size should be 10). group_size = min([len(layers) for layers in same_type_layers.values()]) + + print("group_size: ", group_size) + grouped_layers = [] for layers in same_type_layers.values(): num_padding_layers = group_size - len(layers) % group_size @@ -846,17 +855,26 @@ def _get_kv_cache_config_uniform_page_size( # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2 # full.1, sw.1: share another Tensor with size=available_memory//2 page_size = get_uniform_page_size(kv_cache_spec) + print("page_size: ", page_size) num_blocks = get_num_blocks(vllm_config, group_size, available_memory, page_size) + print("num_blocks: ", num_blocks) per_memory_pool_size = page_size * num_blocks + print("per_memory_pool_size: ", per_memory_pool_size) + kv_cache_tensors = [] for i in range(group_size): + print("i: ", i) shared_by = [] for j in range(len(kv_cache_groups)): + print("j: ", j) + print("grouped_layers[j]: ", grouped_layers[j]) if i < len(grouped_layers[j]): shared_by.append(grouped_layers[j][i]) + print("shared_by: ", shared_by) kv_cache_tensors.append( KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by)) + print("kv_cache_tensors: ", kv_cache_tensors) kv_cache_config = KVCacheConfig( num_blocks=num_blocks, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 5b4718038076..4ccba65ebfbe 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -121,6 +121,10 @@ def allocate_new_blocks(self, request_id: str, req_blocks = self.req_to_blocks[request_id] num_required_blocks = cdiv(num_tokens, self.block_size) num_new_blocks = num_required_blocks - len(req_blocks) + print("[allocate_new_blocks] num_tokens: ", num_tokens) + print("[allocate_new_blocks] block_size: ", self.block_size) + print("[allocate_new_blocks] num_required_blocks: ", num_required_blocks) + if num_new_blocks <= 0: return [] else: diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index c48775adc9b8..34b7d4fc3272 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -169,7 +169,12 @@ def type_id(self) -> str: @property def page_size_bytes(self) -> int: - return self.num_elements * get_dtype_size(self.dtype) + real_page_size = self.num_elements * get_dtype_size(self.dtype) + hack_page_size = 528 * 4096 + print("real_page_size: ", real_page_size) + print("hack_page_size: ", hack_page_size) + assert hack_page_size >= real_page_size + return hack_page_size def max_memory_usage_bytes(self, vllm_config: 
VllmConfig) -> int: # We allocate 1 block for each request now, so max_memory_usage_bytes is diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 330366006118..6534c87fb96e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2355,6 +2355,8 @@ def _reshape_kv_cache_tensors( assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes) + print("layer_name: ", layer_name) + print("num_blocks: ", num_blocks) if isinstance(kv_cache_spec, AttentionSpec): kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, @@ -2388,16 +2390,18 @@ def _reshape_kv_cache_tensors( dtype = kv_cache_spec.dtype state_tensors = [] start_pos = 0 + print("kv_cache_spec.shapes: ", kv_cache_spec.shapes) for shape in kv_cache_spec.shapes: target_shape = (num_blocks, *shape) size_in_bytes = np.prod(shape) * get_dtype_size( dtype) * num_blocks + print("size_in_bytes: ", size_in_bytes) tensor = raw_tensor[start_pos:start_pos + size_in_bytes] tensor = tensor.view(dtype).view(target_shape) state_tensors.append(tensor) start_pos += size_in_bytes - assert start_pos == raw_tensor.numel() + #assert start_pos == raw_tensor.numel() kv_caches[layer_name] = tuple(state_tensors) else: raise NotImplementedError From 617cd26f4d50eda71b7b1d1cb0ff7104f2f659d5 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 24 Jun 2025 08:51:17 +0000 Subject: [PATCH 02/28] working changes Signed-off-by: Thomas Parnell --- .../models/language/generation/test_hybrid.py | 62 +++++++++++++------ vllm/model_executor/models/bamba.py | 8 +-- vllm/v1/core/kv_cache_coordinator.py | 11 ++-- vllm/v1/core/kv_cache_manager.py | 18 ++++-- vllm/v1/core/kv_cache_utils.py | 7 ++- vllm/v1/core/sched/scheduler.py | 13 ++++ vllm/v1/kv_cache_interface.py | 13 ++-- vllm/v1/worker/gpu_model_runner.py | 34 +++++++++- 8 files changed, 125 insertions(+), 41 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 90c4cd968e7a..6d343be0deb6 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -14,36 +14,44 @@ # The rest of the models will only be tested by test_models SSM_MODELS = [ - "state-spaces/mamba-130m-hf", - "tiiuae/falcon-mamba-tiny-dev", + #"state-spaces/mamba-130m-hf", + #"tiiuae/falcon-mamba-tiny-dev", # TODO: Compare to a Mamba2 model. The HF transformers implementation of # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test # doesn't compare vLLM output with HF output. # See https://github.com/huggingface/transformers/pull/35943 - "mistralai/Mamba-Codestral-7B-v0.1", + #"mistralai/Mamba-Codestral-7B-v0.1", ] HYBRID_MODELS = [ - "ai21labs/Jamba-tiny-dev", + #"ai21labs/Jamba-tiny-dev", # NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as # it is not yet available in huggingface transformers # "ibm-granite/granite-4.0-tiny-preview", # NOTE: Running Plamo2 in transformers implementation requires to install # causal-conv1d package, which is not listed as a test dependency as it's # not compatible with pip-compile. 
- "pfnet/plamo-2-1b", - "Zyphra/Zamba2-1.2B-instruct", - "hmellor/tiny-random-BambaForCausalLM", + #"pfnet/plamo-2-1b", + #"Zyphra/Zamba2-1.2B-instruct", + #"hmellor/tiny-random-BambaForCausalLM", + "/net/storage149/autofs/css22/nmg/models/hf/ibm-ai-platform/Bamba-9B-v1/main/", ] V1_SUPPORTED_MODELS = [ "mistralai/Mamba-Codestral-7B-v0.1", + #"hmellor/tiny-random-BambaForCausalLM", + "/net/storage149/autofs/css22/nmg/models/hf/ibm-ai-platform/Bamba-9B-v1/main/" ] + +ATTN_BLOCK_SIZES = { + "hmellor/tiny-random-BambaForCausalLM": 48, + "/net/storage149/autofs/css22/nmg/models/hf/ibm-ai-platform/Bamba-9B-v1/main/": 528 +} + # Avoid OOM MAX_NUM_SEQS = 4 - @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @@ -56,6 +64,9 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: + + #example_prompts = ["Hello World "] + with hf_runner(model) as hf_model: if model != "mistralai/Mamba-Codestral-7B-v0.1": hf_outputs = hf_model.generate_greedy_logprobs_limit( @@ -63,22 +74,33 @@ def test_models( else: hf_outputs = None - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: + + if model in ATTN_BLOCK_SIZES: + block_size = ATTN_BLOCK_SIZES[model] + else: + block_size = 16 + + ''' + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, block_size=block_size) as vllm_model: vllm_v0_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) + ''' if model in V1_SUPPORTED_MODELS: with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER_VLLM_V1") with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, enforce_eager=True, - enable_prefix_caching=False) as vllm_model: + enable_prefix_caching=False, + block_size=block_size) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) else: vllm_v1_outputs = None + ''' if hf_outputs is not None: check_logprobs_close( outputs_0_lst=hf_outputs, @@ -86,6 +108,7 @@ def test_models( name_0="hf", name_1="vllm-v0", ) + ''' if model in V1_SUPPORTED_MODELS: ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs @@ -96,7 +119,7 @@ def test_models( name_1="vllm-v1", ) - +''' @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @@ -126,7 +149,7 @@ def test_batching( ) -@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) @@ -162,7 +185,7 @@ def test_chunked_prefill( ) -@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [10]) def test_chunked_prefill_with_parallel_sampling( vllm_runner, @@ -194,7 +217,7 @@ def test_chunked_prefill_with_parallel_sampling( vllm_model.generate(example_prompts, sampling_params) -@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [20]) def test_mamba_cache_cg_padding( vllm_runner, @@ -223,7 +246,7 @@ def test_mamba_cache_cg_padding( "Could be related to mamba cache not padded correctly") -@pytest.mark.parametrize("model", [SSM_MODELS[0], 
HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [20]) def test_models_preemption_recompute( vllm_runner, @@ -251,7 +274,7 @@ def test_models_preemption_recompute( ) -@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( vllm_runner, example_prompts, @@ -274,7 +297,7 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( "steps finished requests registered unnecessarily ") -@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) def test_state_cleanup( vllm_runner, example_prompts, @@ -295,7 +318,7 @@ def test_state_cleanup( "could be related to finished_requests_ids") -@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [64]) def test_multistep_correctness( vllm_runner, @@ -322,7 +345,7 @@ def test_multistep_correctness( @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_distributed_correctness( @@ -348,3 +371,4 @@ def test_distributed_correctness( name_0="vllm_tp_1", name_1="vllm_tp_2", ) +''' diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 26f537616ed2..ef6ea8bfc780 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -484,18 +484,18 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): + + mamba_cache_params = None if not envs.VLLM_USE_V1: if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) self.mamba_cache = MambaCacheManager( self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, *self._get_mamba_cache_shape()) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - else: - mamba_cache_params = None + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index fad1fc132d9c..70c1554bf976 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -26,6 +26,7 @@ def __init__( ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len + self.enable_caching = enable_caching self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events) @@ -273,10 +274,12 @@ def verify_and_split_kv_cache_groups(self) -> None: self.full_attention_block_size = self.full_attention_spec.block_size self.other_block_size = self.other_spec.block_size - # think this is only needed for prefix caching - #assert self.other_block_size % self.full_attention_block_size == 0, ( - # "KVCacheCoordinator assumes the block_size of full attention " - # "layers is divisible by other layers now.") + + if self.enable_caching: + # this requirement is only needed for the prefix caching logic + assert self.other_block_size % 
self.full_attention_block_size == 0, ( + "KVCacheCoordinator assumes the block_size of full attention " + "layers is divisible by other layers now.") if max(self.full_attention_group_ids) < min(self.other_group_ids): self.full_attn_first = True diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 94ee82b926ca..b25013522b70 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -84,12 +84,18 @@ def __init__( self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - #assert len( - # set(g.kv_cache_spec.block_size - # for g in kv_cache_config.kv_cache_groups) - #) == 1, "Only one block size is supported for now" - self.block_size = kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.block_size + + if self.enable_caching: + assert len( + set(g.kv_cache_spec.block_size + for g in kv_cache_config.kv_cache_groups) + ) == 1, "Only one block size is supported for now" + self.block_size = kv_cache_config.kv_cache_groups[ + 0].kv_cache_spec.block_size + else: + # not needed without caching + self.block_size = None + print("self.block_size: ", self.block_size) self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index d6c9de6e55d1..c368833307b6 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -456,6 +456,9 @@ def hash_request_tokens(hash_function: Any, block_size: int, """ token_ids = request.all_token_ids + print("[hash_request_tokens] block_size: ", block_size) + print("[hash_request_tokens] len(token_ids): ", len(token_ids)) + req_need_extra_keys = need_extra_keys(request) req_extra_keys = None curr_mm_idx = 0 @@ -882,9 +885,11 @@ def _get_kv_cache_config_uniform_page_size( kv_cache_groups=kv_cache_groups, ) + min_block_size = min([group.kv_cache_spec.block_size for group in kv_cache_groups]) + # Print the KV cache size and maximum concurrency. num_tokens = num_blocks // len( - grouped_layers) * vllm_config.cache_config.block_size + grouped_layers) * min_block_size num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 00b0844a5660..f7644cfd672f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -221,6 +221,13 @@ def schedule(self) -> SchedulerOutput: num_new_tokens, self.max_model_len - request.num_computed_tokens) + if num_new_tokens > 1: + logger.info("request_id: %s", request.request_id) + logger.info("num_tokens: %d", request.num_tokens) + logger.info("num_computed_tokens: %d", + request.num_computed_tokens) + logger.info("num_new_tokens: %d", num_new_tokens) + # Schedule encoder inputs. encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget @@ -431,6 +438,12 @@ def schedule(self) -> SchedulerOutput: num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 + if True: + logger.info("request_id: %s", request.request_id) + logger.info("num_tokens: %d", request.num_tokens) + logger.info("num_computed_tokens: %d", num_computed_tokens) + logger.info("num_new_tokens: %d", num_new_tokens) + # Schedule encoder inputs. 
if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 34b7d4fc3272..8cd8b4d806ff 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -159,6 +159,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: class MambaSpec(KVCacheSpec): shapes: tuple[tuple[int, ...], ...] dtype: torch.dtype + multiple_of: Optional[int] def __post_init__(self): self.num_elements = sum(prod(shape) for shape in self.shapes) @@ -169,12 +170,12 @@ def type_id(self) -> str: @property def page_size_bytes(self) -> int: - real_page_size = self.num_elements * get_dtype_size(self.dtype) - hack_page_size = 528 * 4096 - print("real_page_size: ", real_page_size) - print("hack_page_size: ", hack_page_size) - assert hack_page_size >= real_page_size - return hack_page_size + page_size = self.num_elements * get_dtype_size(self.dtype) + print("real_page_size: ", page_size) + if self.multiple_of is not None: + page_size = cdiv(page_size, self.multiple_of) * self.multiple_of + print("padded page size: ", page_size) + return page_size def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: # We allocate 1 block for each request now, so max_memory_usage_bytes is diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6534c87fb96e..6ee1f19311ab 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -318,9 +318,19 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: Returns: True if the batch was reordered, False otherwise. """ + + ''' + for i in range(0, len(self.kv_cache_config.kv_cache_groups)): + print(self.attn_metadata_builders[i]) + print(self.attn_metadata_builders[i].reorder_batch( + self.input_batch, scheduler_output)) + ''' + batch_reordered = self.attn_metadata_builders[0].reorder_batch( self.input_batch, scheduler_output) + torch.cuda.synchronize() + # For models with multiple KV cache groups, the groups should agree on # the same order of requests. We ensure this by only allowing the first # group to reorder the batch and asserting that all other groups do not @@ -328,6 +338,7 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: for i in range(1, len(self.kv_cache_config.kv_cache_groups)): assert not self.attn_metadata_builders[i].reorder_batch( self.input_batch, scheduler_output) + return batch_reordered # Note: used for model runner override. 
@@ -2288,6 +2299,7 @@ def may_reinitialize_input_batch(self, kv_cache_group.kv_cache_spec.block_size for kv_cache_group in kv_cache_config.kv_cache_groups ] + print("[may_reinitialize_input_batch] block_sizes: ", block_sizes) if block_sizes != [self.cache_config.block_size]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " @@ -2515,6 +2527,17 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaMixer2) if len(mamba_layers) > 0: + + #if len(attn_layers) > 0: + if True: + # Mamba state must be padded to an integer number of + # 16th tokens worth of attention pages + attn_layer_name = next(iter(attn_layers)) + attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes + multiple_of = 16 * attn_page_size // block_size + else: + multiple_of = None + if self.vllm_config.speculative_config is not None: raise NotImplementedError( "Mamba with speculative decoding is not supported yet.") @@ -2531,5 +2554,14 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: kv_cache_spec[layer_name] = MambaSpec( shapes=mamba_module.get_state_shape(), dtype=self.kv_cache_dtype, - block_size=max_model_len) + block_size=max_model_len, + multiple_of=multiple_of) + + if len(attn_layers) > 0: + mamba_layer_name = next(iter(mamba_layers)) + mamba_page_size = kv_cache_spec[mamba_layer_name].page_size_bytes + if attn_page_size < mamba_page_size: + suggested_attn_block_size = cdiv(mamba_page_size, multiple_of)*16 + raise ValueError(f"Attention block size must be increased to {suggested_attn_block_size} in order to match mamba page size") + return kv_cache_spec From 300d25fd2bac951875caa7b6277de2158716ae13 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 24 Jun 2025 09:53:18 +0000 Subject: [PATCH 03/28] Working version Signed-off-by: Thomas Parnell --- .../models/language/generation/test_hybrid.py | 66 ++++++++----------- vllm/model_executor/models/bamba.py | 18 ++--- vllm/v1/core/kv_cache_coordinator.py | 13 ++-- vllm/v1/core/kv_cache_manager.py | 6 +- vllm/v1/core/kv_cache_utils.py | 29 ++------ vllm/v1/core/sched/scheduler.py | 13 ---- vllm/v1/core/single_type_kv_cache_manager.py | 4 -- vllm/v1/kv_cache_interface.py | 2 - vllm/v1/worker/gpu_model_runner.py | 34 ++++------ 9 files changed, 58 insertions(+), 127 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 6d343be0deb6..82820947d9f7 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -14,47 +14,45 @@ # The rest of the models will only be tested by test_models SSM_MODELS = [ - #"state-spaces/mamba-130m-hf", - #"tiiuae/falcon-mamba-tiny-dev", + "state-spaces/mamba-130m-hf", + "tiiuae/falcon-mamba-tiny-dev", # TODO: Compare to a Mamba2 model. The HF transformers implementation of # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test # doesn't compare vLLM output with HF output. 
# See https://github.com/huggingface/transformers/pull/35943 - #"mistralai/Mamba-Codestral-7B-v0.1", + "mistralai/Mamba-Codestral-7B-v0.1", ] HYBRID_MODELS = [ - #"ai21labs/Jamba-tiny-dev", + "ai21labs/Jamba-tiny-dev", # NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as # it is not yet available in huggingface transformers # "ibm-granite/granite-4.0-tiny-preview", # NOTE: Running Plamo2 in transformers implementation requires to install # causal-conv1d package, which is not listed as a test dependency as it's # not compatible with pip-compile. - #"pfnet/plamo-2-1b", - #"Zyphra/Zamba2-1.2B-instruct", - #"hmellor/tiny-random-BambaForCausalLM", - "/net/storage149/autofs/css22/nmg/models/hf/ibm-ai-platform/Bamba-9B-v1/main/", + "pfnet/plamo-2-1b", + "Zyphra/Zamba2-1.2B-instruct", + "hmellor/tiny-random-BambaForCausalLM", + "ibm-ai-platform/Bamba-9B-v1", ] V1_SUPPORTED_MODELS = [ "mistralai/Mamba-Codestral-7B-v0.1", - #"hmellor/tiny-random-BambaForCausalLM", - "/net/storage149/autofs/css22/nmg/models/hf/ibm-ai-platform/Bamba-9B-v1/main/" + "ibm-ai-platform/Bamba-9B-v1", ] - ATTN_BLOCK_SIZES = { - "hmellor/tiny-random-BambaForCausalLM": 48, - "/net/storage149/autofs/css22/nmg/models/hf/ibm-ai-platform/Bamba-9B-v1/main/": 528 + "ibm-ai-platform/Bamba-9B-v1": 528, } # Avoid OOM MAX_NUM_SEQS = 4 + @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("num_logprobs", [10]) def test_models( hf_runner, vllm_runner, @@ -64,9 +62,6 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - - #example_prompts = ["Hello World "] - with hf_runner(model) as hf_model: if model != "mistralai/Mamba-Codestral-7B-v0.1": hf_outputs = hf_model.generate_greedy_logprobs_limit( @@ -74,22 +69,18 @@ def test_models( else: hf_outputs = None - - if model in ATTN_BLOCK_SIZES: - block_size = ATTN_BLOCK_SIZES[model] - else: - block_size = 16 - - ''' - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, block_size=block_size) as vllm_model: + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_v0_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - ''' if model in V1_SUPPORTED_MODELS: with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER_VLLM_V1") + if model in HYBRID_MODELS: + # required due to reorder_batch behaviour + m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + # set attn block size to match mamba state + block_size = ATTN_BLOCK_SIZES.get(model, 16) with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, enforce_eager=True, @@ -100,7 +91,6 @@ def test_models( else: vllm_v1_outputs = None - ''' if hf_outputs is not None: check_logprobs_close( outputs_0_lst=hf_outputs, @@ -108,7 +98,6 @@ def test_models( name_0="hf", name_1="vllm-v0", ) - ''' if model in V1_SUPPORTED_MODELS: ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs @@ -119,7 +108,7 @@ def test_models( name_1="vllm-v1", ) -''' + @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @@ -149,7 +138,7 @@ def test_batching( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) 
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) @@ -185,7 +174,7 @@ def test_chunked_prefill( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [10]) def test_chunked_prefill_with_parallel_sampling( vllm_runner, @@ -217,7 +206,7 @@ def test_chunked_prefill_with_parallel_sampling( vllm_model.generate(example_prompts, sampling_params) -@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [20]) def test_mamba_cache_cg_padding( vllm_runner, @@ -246,7 +235,7 @@ def test_mamba_cache_cg_padding( "Could be related to mamba cache not padded correctly") -@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [20]) def test_models_preemption_recompute( vllm_runner, @@ -274,7 +263,7 @@ def test_models_preemption_recompute( ) -@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( vllm_runner, example_prompts, @@ -297,7 +286,7 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( "steps finished requests registered unnecessarily ") -@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) def test_state_cleanup( vllm_runner, example_prompts, @@ -318,7 +307,7 @@ def test_state_cleanup( "could be related to finished_requests_ids") -@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [64]) def test_multistep_correctness( vllm_runner, @@ -345,7 +334,7 @@ def test_multistep_correctness( @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("model", [HYBRID_MODELS[0]]) +@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_distributed_correctness( @@ -371,4 +360,3 @@ def test_distributed_correctness( name_0="vllm_tp_1", name_1="vllm_tp_2", ) -''' diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index ef6ea8bfc780..d743c52074c6 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -37,7 +37,7 @@ from vllm.utils import LayerBlockType from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsQuant, SupportsV0Only) + SupportsQuant) from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -314,8 +314,6 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - print("self.config.mamba_chunk_size: ", self.config.mamba_chunk_size) - attn_metadata = get_forward_context().attn_metadata if not envs.VLLM_USE_V1: @@ -346,7 +344,8 @@ def forward( num_attn += 1 layer_mamba_cache_params = None - if isinstance(layer, BambaMixerDecoderLayer) and mamba_cache_params: + if isinstance(layer, + BambaMixerDecoderLayer) and mamba_cache_params: layer_mamba_cache_params = mamba_cache_params.at_layer_idx( i - num_attn) @@ -488,12 +487,15 @@ def forward(self, mamba_cache_params = None if not envs.VLLM_USE_V1: if self.mamba_cache is 
None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) + num_mamba_layers = \ + self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, + LayerBlockType.mamba + ) self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) + self.vllm_config, self.lm_head.weight.dtype, + num_mamba_layers, *self._get_mamba_cache_shape()) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 70c1554bf976..9ff1dc23ea42 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -42,8 +42,6 @@ def __init__( ) for i, kv_cache_group in enumerate( self.kv_cache_config.kv_cache_groups)) - print("single_type_managers: ", self.single_type_managers) - def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int: @@ -60,14 +58,10 @@ def get_num_blocks_to_allocate( Returns: The number of blocks. """ - print("request_id: ", request_id) - print("num_tokens: ", num_tokens) - print("new_computed_blocks: ", new_computed_blocks) num_blocks_to_allocate = 0 for i, manager in enumerate(self.single_type_managers): num_blocks_to_allocate += manager.get_num_blocks_to_allocate( request_id, num_tokens, new_computed_blocks[i]) - print("i: ", i, " manager: ", manager, " num_blocks_to_allocate: ", num_blocks_to_allocate) return num_blocks_to_allocate def save_new_computed_blocks( @@ -277,9 +271,10 @@ def verify_and_split_kv_cache_groups(self) -> None: if self.enable_caching: # this requirement is only needed for the prefix caching logic - assert self.other_block_size % self.full_attention_block_size == 0, ( - "KVCacheCoordinator assumes the block_size of full attention " - "layers is divisible by other layers now.") + divisible = self.other_block_size % self.full_attention_block_size + assert divisible == 0, ( + "KVCacheCoordinator assumes the block_size of full " + "attention layers is divisible by other layers now.") if max(self.full_attention_group_ids) < min(self.other_group_ids): self.full_attn_first = True diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index b25013522b70..6937455e7d85 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -85,6 +85,7 @@ def __init__( # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None + self.block_size: Optional[int] = None if self.enable_caching: assert len( set(g.kv_cache_spec.block_size @@ -92,11 +93,7 @@ def __init__( ) == 1, "Only one block size is supported for now" self.block_size = kv_cache_config.kv_cache_groups[ 0].kv_cache_spec.block_size - else: - # not needed without caching - self.block_size = None - print("self.block_size: ", self.block_size) self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, @@ -160,6 +157,7 @@ def get_computed_blocks(self, # if the scheduler has tried to schedule the request before. 
block_hashes = self.req_to_block_hashes[request.request_id] if not block_hashes: + assert self.block_size is not None block_hashes = hash_request_tokens(self.caching_hash_fn, self.block_size, request) self.req_to_block_hashes[request.request_id] = block_hashes diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index c368833307b6..2fbcb569e3d5 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -456,9 +456,6 @@ def hash_request_tokens(hash_function: Any, block_size: int, """ token_ids = request.all_token_ids - print("[hash_request_tokens] block_size: ", block_size) - print("[hash_request_tokens] len(token_ids): ", len(token_ids)) - req_need_extra_keys = need_extra_keys(request) req_extra_keys = None curr_mm_idx = 0 @@ -608,7 +605,6 @@ def create_kv_cache_group_specs( merged_layer_spec = layer_specs[0].merge(layer_specs) kv_cache_groups.append( KVCacheGroupSpec(layer_names_one_group, merged_layer_spec)) - print("len(kv_cache_groups): ", len(kv_cache_groups)) return kv_cache_groups @@ -696,6 +692,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, page_size = get_uniform_page_size(kv_cache_spec) num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec), available_memory, page_size) + per_layer_size = page_size * num_blocks # All layers have the same KV cache spec, so we create one kv cache group # for all layers. @@ -737,8 +734,6 @@ def is_kv_cache_page_size_uniform( """ page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} - print("page_sizes: ", page_sizes) - return len(page_sizes) == 1 @@ -816,10 +811,6 @@ def _get_kv_cache_config_uniform_page_size( for layer_name, layer_spec in kv_cache_spec.items(): same_type_layers[layer_spec.type_id].append(layer_name) - - for layer_type in same_type_layers.keys(): - print(layer_type, len(same_type_layers[layer_type])) - # Split each group into smaller groups, to make the number of layers in each # group identical. Add padding to the last group of each type if necessary. # E.g., (full.0, full.1), (sw.0, sw.1, sw.2) @@ -833,9 +824,6 @@ def _get_kv_cache_config_uniform_page_size( # strategy if we want to support more complex patterns (e.g., 20 full + 30 # sw, where the group size should be 10). 
group_size = min([len(layers) for layers in same_type_layers.values()]) - - print("group_size: ", group_size) - grouped_layers = [] for layers in same_type_layers.values(): num_padding_layers = group_size - len(layers) % group_size @@ -858,26 +846,17 @@ def _get_kv_cache_config_uniform_page_size( # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2 # full.1, sw.1: share another Tensor with size=available_memory//2 page_size = get_uniform_page_size(kv_cache_spec) - print("page_size: ", page_size) num_blocks = get_num_blocks(vllm_config, group_size, available_memory, page_size) - print("num_blocks: ", num_blocks) per_memory_pool_size = page_size * num_blocks - print("per_memory_pool_size: ", per_memory_pool_size) - kv_cache_tensors = [] for i in range(group_size): - print("i: ", i) shared_by = [] for j in range(len(kv_cache_groups)): - print("j: ", j) - print("grouped_layers[j]: ", grouped_layers[j]) if i < len(grouped_layers[j]): shared_by.append(grouped_layers[j][i]) - print("shared_by: ", shared_by) kv_cache_tensors.append( KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by)) - print("kv_cache_tensors: ", kv_cache_tensors) kv_cache_config = KVCacheConfig( num_blocks=num_blocks, @@ -885,11 +864,11 @@ def _get_kv_cache_config_uniform_page_size( kv_cache_groups=kv_cache_groups, ) - min_block_size = min([group.kv_cache_spec.block_size for group in kv_cache_groups]) + min_block_size = min( + [group.kv_cache_spec.block_size for group in kv_cache_groups]) # Print the KV cache size and maximum concurrency. - num_tokens = num_blocks // len( - grouped_layers) * min_block_size + num_tokens = num_blocks // len(grouped_layers) * min_block_size num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f7644cfd672f..00b0844a5660 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -221,13 +221,6 @@ def schedule(self) -> SchedulerOutput: num_new_tokens, self.max_model_len - request.num_computed_tokens) - if num_new_tokens > 1: - logger.info("request_id: %s", request.request_id) - logger.info("num_tokens: %d", request.num_tokens) - logger.info("num_computed_tokens: %d", - request.num_computed_tokens) - logger.info("num_new_tokens: %d", num_new_tokens) - # Schedule encoder inputs. encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget @@ -438,12 +431,6 @@ def schedule(self) -> SchedulerOutput: num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 - if True: - logger.info("request_id: %s", request.request_id) - logger.info("num_tokens: %d", request.num_tokens) - logger.info("num_computed_tokens: %d", num_computed_tokens) - logger.info("num_new_tokens: %d", num_new_tokens) - # Schedule encoder inputs. 
if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 4ccba65ebfbe..5b4718038076 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -121,10 +121,6 @@ def allocate_new_blocks(self, request_id: str, req_blocks = self.req_to_blocks[request_id] num_required_blocks = cdiv(num_tokens, self.block_size) num_new_blocks = num_required_blocks - len(req_blocks) - print("[allocate_new_blocks] num_tokens: ", num_tokens) - print("[allocate_new_blocks] block_size: ", self.block_size) - print("[allocate_new_blocks] num_required_blocks: ", num_required_blocks) - if num_new_blocks <= 0: return [] else: diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 8cd8b4d806ff..f1e0837dd140 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -171,10 +171,8 @@ def type_id(self) -> str: @property def page_size_bytes(self) -> int: page_size = self.num_elements * get_dtype_size(self.dtype) - print("real_page_size: ", page_size) if self.multiple_of is not None: page_size = cdiv(page_size, self.multiple_of) * self.multiple_of - print("padded page size: ", page_size) return page_size def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9651bf1a1694..148960da1071 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -320,19 +320,9 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: Returns: True if the batch was reordered, False otherwise. """ - - ''' - for i in range(0, len(self.kv_cache_config.kv_cache_groups)): - print(self.attn_metadata_builders[i]) - print(self.attn_metadata_builders[i].reorder_batch( - self.input_batch, scheduler_output)) - ''' - batch_reordered = self.attn_metadata_builders[0].reorder_batch( self.input_batch, scheduler_output) - torch.cuda.synchronize() - # For models with multiple KV cache groups, the groups should agree on # the same order of requests. We ensure this by only allowing the first # group to reorder the batch and asserting that all other groups do not @@ -340,7 +330,6 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: for i in range(1, len(self.kv_cache_config.kv_cache_groups)): assert not self.attn_metadata_builders[i].reorder_batch( self.input_batch, scheduler_output) - return batch_reordered # Note: used for model runner override. 
@@ -2301,7 +2290,6 @@ def may_reinitialize_input_batch(self, kv_cache_group.kv_cache_spec.block_size for kv_cache_group in kv_cache_config.kv_cache_groups ] - print("[may_reinitialize_input_batch] block_sizes: ", block_sizes) if block_sizes != [self.cache_config.block_size]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " @@ -2369,8 +2357,6 @@ def _reshape_kv_cache_tensors( assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes) - print("layer_name: ", layer_name) - print("num_blocks: ", num_blocks) if isinstance(kv_cache_spec, AttentionSpec): kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, @@ -2404,18 +2390,17 @@ def _reshape_kv_cache_tensors( dtype = kv_cache_spec.dtype state_tensors = [] start_pos = 0 - print("kv_cache_spec.shapes: ", kv_cache_spec.shapes) for shape in kv_cache_spec.shapes: target_shape = (num_blocks, *shape) size_in_bytes = np.prod(shape) * get_dtype_size( dtype) * num_blocks - print("size_in_bytes: ", size_in_bytes) tensor = raw_tensor[start_pos:start_pos + size_in_bytes] tensor = tensor.view(dtype).view(target_shape) state_tensors.append(tensor) start_pos += size_in_bytes - #assert start_pos == raw_tensor.numel() + if kv_cache_spec.multiple_of is None: + assert start_pos == raw_tensor.numel() kv_caches[layer_name] = tuple(state_tensors) else: raise NotImplementedError @@ -2529,9 +2514,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaMixer2) if len(mamba_layers) > 0: - - #if len(attn_layers) > 0: - if True: + if len(attn_layers) > 0: # Mamba state must be padded to an integer number of # 16th tokens worth of attention pages attn_layer_name = next(iter(attn_layers)) @@ -2561,9 +2544,14 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: if len(attn_layers) > 0: mamba_layer_name = next(iter(mamba_layers)) - mamba_page_size = kv_cache_spec[mamba_layer_name].page_size_bytes + mamba_page_size = kv_cache_spec[ + mamba_layer_name].page_size_bytes if attn_page_size < mamba_page_size: - suggested_attn_block_size = cdiv(mamba_page_size, multiple_of)*16 - raise ValueError(f"Attention block size must be increased to {suggested_attn_block_size} in order to match mamba page size") + required_attn_block_size = cdiv(mamba_page_size, + multiple_of) * 16 + raise ValueError( + "Attention block size must be increased to " + f"{required_attn_block_size} in order to match " + "the mamba page size") return kv_cache_spec From 08223081c5a58e2a8b7dcdd252b81d35aafe18b8 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 24 Jun 2025 21:07:07 +0000 Subject: [PATCH 04/28] Add support + test for Zamba2 Signed-off-by: Thomas Parnell --- .../models/language/generation/test_hybrid.py | 14 ++- vllm/model_executor/models/zamba2.py | 101 +++++++++++------- 2 files changed, 74 insertions(+), 41 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 82820947d9f7..e92892021916 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -34,16 +34,19 @@ "pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", "hmellor/tiny-random-BambaForCausalLM", - "ibm-ai-platform/Bamba-9B-v1", + # too big for CI + #"ibm-ai-platform/Bamba-9B-v1", ] V1_SUPPORTED_MODELS = [ "mistralai/Mamba-Codestral-7B-v0.1", 
"ibm-ai-platform/Bamba-9B-v1", + "Zyphra/Zamba2-1.2B-instruct", ] ATTN_BLOCK_SIZES = { "ibm-ai-platform/Bamba-9B-v1": 528, + "Zyphra/Zamba2-1.2B-instruct": 80, } # Avoid OOM @@ -52,7 +55,7 @@ @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS) @pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [10]) +@pytest.mark.parametrize("num_logprobs", [5]) def test_models( hf_runner, vllm_runner, @@ -74,13 +77,16 @@ def test_models( example_prompts, max_tokens, num_logprobs) if model in V1_SUPPORTED_MODELS: + if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES: + block_size = ATTN_BLOCK_SIZES[model] + else: + block_size = 16 + with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") if model in HYBRID_MODELS: # required due to reorder_batch behaviour m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - # set attn block size to match mamba state - block_size = ATTN_BLOCK_SIZES.get(model, 16) with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, enforce_eager=True, diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index a4f97c774f70..c273f70777d2 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -15,6 +15,7 @@ from torch import nn from transformers import Zamba2Config +from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -41,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import HasInnerState, IsHybrid, SupportsV0Only +from .interfaces import HasInnerState, IsHybrid from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix @@ -58,6 +59,7 @@ def __init__( rank: int, output_dim: Union[int, list[int]], quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): """Initialize the attention layer. @@ -283,6 +285,7 @@ def __init__( bare_block_idx: int, num_hybrid_layers: dict[int, int], quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: """Initialize the MLP layer. @@ -471,11 +474,10 @@ class Zamba2MambaDecoderLayer(nn.Module): computation depending on configuration. """ - def __init__( - self, - config: Zamba2Config, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: + def __init__(self, + config: Zamba2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: """Initialize the Mamba decoder layer. 
Args: @@ -486,20 +488,21 @@ def __init__( # Initialize Mamba mixer with expanded intermediate size intermediate_size = config.mamba_expand * config.hidden_size - self.mamba = MambaMixer2( - hidden_size=config.hidden_size, - ssm_state_size=config.mamba_d_state, - conv_kernel_size=config.mamba_d_conv, - intermediate_size=intermediate_size, - use_conv_bias=config.use_conv_bias, - use_bias=config.add_bias_linear, - n_groups=config.mamba_ngroups, - num_heads=config.n_mamba_heads, - head_dim=intermediate_size // config.n_mamba_heads, - rms_norm_eps=config.rms_norm_eps, - activation="silu", - quant_config=quant_config, - ) + self.mamba = MambaMixer2(hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=intermediate_size, + use_conv_bias=config.use_conv_bias, + use_bias=config.add_bias_linear, + n_groups=config.mamba_ngroups, + num_heads=config.n_mamba_heads, + head_dim=intermediate_size // + config.n_mamba_heads, + rms_norm_eps=config.rms_norm_eps, + activation="silu", + quant_config=quant_config, + prefix=f"{prefix}.mixer", + chunk_size=config.chunk_size) # Input normalization self.input_layernorm = RMSNorm(config.hidden_size, @@ -573,6 +576,7 @@ def __init__( config: Zamba2Config, block_idx: int, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: """Initialize the hybrid layer. @@ -589,7 +593,8 @@ def __init__( bias=False, quant_config=quant_config) self.mamba_decoder = Zamba2MambaDecoderLayer(config, - quant_config=quant_config) + quant_config=quant_config, + prefix=prefix) def forward( self, @@ -699,14 +704,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # Initialize layers according to block type configuration layers = [] for layer_idx, layer_type in enumerate(config.layers_block_type): + # tpa: avoid layers getting same index + # somewhat hacky but correct (I think) + prefix = str(len(layer2block_map) + layer_idx) if layer_type == "hybrid": block = next(blocks) block_idx = layer2block_map[layer_idx] layers.append( - Zamba2HybridLayer(block, config, block_idx, quant_config)) + Zamba2HybridLayer(block, + config, + block_idx, + quant_config, + prefix=prefix)) else: layers.append( - Zamba2MambaDecoderLayer(config, quant_config=quant_config)) + Zamba2MambaDecoderLayer(config, + quant_config=quant_config, + prefix=prefix)) self.layers = nn.ModuleList(layers) # Final layer normalization @@ -751,19 +765,30 @@ def forward( attn_metadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None # Process through layers original_hidden_states = torch.clone(hidden_states) for layer_idx, layer in enumerate(self.layers): + + layer_mamba_cache_params = None + if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer)) + and mamba_cache_params): + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + layer_idx) + layer_outputs = layer( hidden_states, original_hidden_states=original_hidden_states, positions=positions, - mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx), + mamba_cache_params=layer_mamba_cache_params, mamba2_metadata=mamba2_metadata, ) hidden_states = layer_outputs @@ -803,7 +828,7 @@ def load_weights(self, 
weights: Iterable[tuple[str, return loaded_params -class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): +class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): """Zamba2 model with causal language modeling head. This class wraps the core Zamba2 model and adds: @@ -897,14 +922,16 @@ def forward(self, Output hidden states """ # Initialize Mamba cache if needed - if self.mamba_cache is None: - num_mamba_layers = self.config.num_hidden_layers - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) - - # Get cache parameters for current run - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_mamba_layers = self.config.num_hidden_layers + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, + num_mamba_layers, *self._get_mamba_cache_shape()) + + # Get cache parameters for current run + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) # Forward pass through model hidden_states = self.model( From a9fc73f40dc0537d87256c851d76802b3ae5b873 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 26 Jun 2025 09:20:08 +0000 Subject: [PATCH 05/28] Fix memory layout for KV cache tensors in mamba case Signed-off-by: Thomas Parnell --- vllm/v1/worker/gpu_model_runner.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 148960da1071..af4f3e714667 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2388,22 +2388,27 @@ def _reshape_kv_cache_tensors( elif isinstance(kv_cache_spec, MambaSpec): raw_tensor = kv_cache_raw_tensors[layer_name] dtype = kv_cache_spec.dtype + page_size = kv_cache_spec.page_size_bytes // get_dtype_size( + dtype) state_tensors = [] - start_pos = 0 + storage_offset = 0 for shape in kv_cache_spec.shapes: target_shape = (num_blocks, *shape) - size_in_bytes = np.prod(shape) * get_dtype_size( - dtype) * num_blocks - tensor = raw_tensor[start_pos:start_pos + - size_in_bytes] - tensor = tensor.view(dtype).view(target_shape) + stride = torch.empty(target_shape).stride() + target_stride = (page_size, *stride[1:]) + tensor = torch.as_strided( + raw_tensor.view(dtype), + size=target_shape, + stride=target_stride, + storage_offset=storage_offset, + ) state_tensors.append(tensor) - start_pos += size_in_bytes - if kv_cache_spec.multiple_of is None: - assert start_pos == raw_tensor.numel() - kv_caches[layer_name] = tuple(state_tensors) + storage_offset += stride[0] + + kv_caches[layer_name] = state_tensors else: raise NotImplementedError + return kv_caches def initialize_kv_cache_tensors( From 0e5b6de2311e47897d1bf5c660dfd153e94bdeb4 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 26 Jun 2025 11:16:41 +0000 Subject: [PATCH 06/28] kv_cache_interface.py: use utils.round_up Signed-off-by: Thomas Parnell --- vllm/v1/kv_cache_interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index f1e0837dd140..31e2ac23f5e4 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -11,7 +11,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import cdiv, get_dtype_size +from vllm.utils import cdiv, get_dtype_size, round_up logger 
= init_logger(__name__) @@ -172,7 +172,7 @@ def type_id(self) -> str: def page_size_bytes(self) -> int: page_size = self.num_elements * get_dtype_size(self.dtype) if self.multiple_of is not None: - page_size = cdiv(page_size, self.multiple_of) * self.multiple_of + page_size = round_up(page_size, self.multiple_of) return page_size def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: From 89f504a1e2a593d9af74bccd20b3188c21ba9599 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 26 Jun 2025 11:46:05 +0000 Subject: [PATCH 07/28] Fix unrelated CI test failing Signed-off-by: Thomas Parnell --- .../models/language/generation/test_gemma.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/models/language/generation/test_gemma.py b/tests/models/language/generation/test_gemma.py index ed0f0c19a041..3d4ceb17a623 100644 --- a/tests/models/language/generation/test_gemma.py +++ b/tests/models/language/generation/test_gemma.py @@ -7,14 +7,17 @@ @pytest.mark.parametrize("model", MODELS) -def test_dummy_loader(vllm_runner, model: str) -> None: - with vllm_runner( - model, - load_format="dummy", - ) as llm: - normalizers = llm.collective_rpc(lambda self: self.worker.model_runner. - model.model.normalizer.cpu().item()) - assert np.allclose( - normalizers, - llm.llm_engine.model_config.hf_config.hidden_size**0.5, - rtol=1e-3) +def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None: + with monkeypatch.context() as m: + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + with vllm_runner( + model, + load_format="dummy", + ) as llm: + normalizers = llm.model.collective_rpc( + lambda self: self.model_runner.model.model.normalizer.cpu( + ).item()) + assert np.allclose( + normalizers, + llm.model.llm_engine.model_config.hf_config.hidden_size**0.5, + rtol=1e-3) From 0b7783b6d0d16a7293f22d6690c007ed2d2df5fc Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 26 Jun 2025 11:47:43 +0000 Subject: [PATCH 08/28] Enable bamba-9b in CI Signed-off-by: Thomas Parnell --- tests/models/language/generation/test_hybrid.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index e92892021916..6ca7db0c5f8d 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -34,8 +34,7 @@ "pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", "hmellor/tiny-random-BambaForCausalLM", - # too big for CI - #"ibm-ai-platform/Bamba-9B-v1", + "ibm-ai-platform/Bamba-9B-v1", ] V1_SUPPORTED_MODELS = [ From ded4833d7e27343ddcbcea9000e3d38668c10d5f Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 26 Jun 2025 15:38:10 +0000 Subject: [PATCH 09/28] Fix unrelated CI issue Signed-off-by: Thomas Parnell --- tests/models/language/generation/test_gemma.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/models/language/generation/test_gemma.py b/tests/models/language/generation/test_gemma.py index 3d4ceb17a623..5be4ae874e61 100644 --- a/tests/models/language/generation/test_gemma.py +++ b/tests/models/language/generation/test_gemma.py @@ -14,10 +14,14 @@ def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None: model, load_format="dummy", ) as llm: - normalizers = llm.model.collective_rpc( - lambda self: self.model_runner.model.model.normalizer.cpu( - ).item()) - assert np.allclose( - normalizers, - 
llm.model.llm_engine.model_config.hf_config.hidden_size**0.5, - rtol=1e-3) + if model == "google/gemma-3-4b-it": + normalizers = llm.model.collective_rpc( + lambda self: self.model_runner.model.language_model.model. + normalizer.cpu().item()) + config = llm.model.llm_engine.model_config.hf_config.text_config + else: + normalizers = llm.model.collective_rpc( + lambda self: self.model_runner.model.model.normalizer.cpu( + ).item()) + config = llm.model.llm_engine.model_config.hf_config + assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3) From 31db869218bd32b948f92cfbca49421a7aa27c9f Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 27 Jun 2025 13:04:21 +0000 Subject: [PATCH 10/28] Add support for Nemotron-H Signed-off-by: Thomas Parnell --- .../models/language/generation/test_hybrid.py | 10 ++++- vllm/model_executor/models/nemotron_h.py | 45 ++++++++++++------- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 6ca7db0c5f8d..2e6ae966c4d2 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -35,17 +35,22 @@ "Zyphra/Zamba2-1.2B-instruct", "hmellor/tiny-random-BambaForCausalLM", "ibm-ai-platform/Bamba-9B-v1", + # Note: I'm not seeing the same output from vLLM V0 vs. HF transformers + # for Nemotron-H-8B; currently only compare vLLM V0 vs. vLLM V1 + "nvidia/Nemotron-H-8B-Base-8K", ] V1_SUPPORTED_MODELS = [ "mistralai/Mamba-Codestral-7B-v0.1", "ibm-ai-platform/Bamba-9B-v1", "Zyphra/Zamba2-1.2B-instruct", + "nvidia/Nemotron-H-8B-Base-8K", ] ATTN_BLOCK_SIZES = { "ibm-ai-platform/Bamba-9B-v1": 528, "Zyphra/Zamba2-1.2B-instruct": 80, + "nvidia/Nemotron-H-8B-Base-8K": 528, } # Avoid OOM @@ -65,7 +70,10 @@ def test_models( num_logprobs: int, ) -> None: with hf_runner(model) as hf_model: - if model != "mistralai/Mamba-Codestral-7B-v0.1": + if model not in [ + "mistralai/Mamba-Codestral-7B-v0.1", + "nvidia/Nemotron-H-8B-Base-8K" + ]: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) else: diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 3424efa80d48..5d51b01df9db 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -23,6 +23,7 @@ import torch from torch import nn +from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -44,8 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsQuant, - SupportsV0Only) + SupportsQuant) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.models.utils import ( @@ -153,6 +153,8 @@ def __init__( rms_norm_eps=config.rms_norm_eps, activation=config.mamba_hidden_act, quant_config=quant_config, + prefix=f"{prefix}.mixer", + chunk_size=config.chunk_size, ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -348,10 +350,14 @@ def forward( attn_metadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.chunk_size, - attn_metadata=attn_metadata, - ) + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + 
chunk_size=self.config.chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -369,7 +375,8 @@ def forward( for i in range(len(self.layers)): layer = self.layers[i] layer_mamba_cache_params = None - if isinstance(layer, NemotronHMambaDecoderLayer): + if isinstance(layer, + NemotronHMambaDecoderLayer) and mamba_cache_params: layer_mamba_cache_params = mamba_cache_params.at_layer_idx( i - num_non_mamba_layers) else: @@ -437,7 +444,7 @@ def load_weights(self, weights: Iterable[tuple[str, class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsV0Only, SupportsQuant): + IsHybrid, SupportsQuant): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -499,15 +506,23 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + + num_mamba_layers = \ + self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, + LayerBlockType.mamba + ) + + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, + num_mamba_layers, *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) From c45e7e5e97baccc54e0196ac31a2d7111959ef58 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 27 Jun 2025 13:21:18 +0000 Subject: [PATCH 11/28] Enable Granite 4.0 (HybridMoE) Signed-off-by: Thomas Parnell --- .../models/language/generation/test_hybrid.py | 28 +++++++---- .../model_executor/models/granitemoehybrid.py | 50 ++++++++++++------- 2 files changed, 50 insertions(+), 28 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 2e6ae966c4d2..d4df38c879c3 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -16,18 +16,11 @@ SSM_MODELS = [ "state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev", - # TODO: Compare to a Mamba2 model. The HF transformers implementation of - # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test - # doesn't compare vLLM output with HF output. - # See https://github.com/huggingface/transformers/pull/35943 "mistralai/Mamba-Codestral-7B-v0.1", ] HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", - # NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as - # it is not yet available in huggingface transformers - # "ibm-granite/granite-4.0-tiny-preview", # NOTE: Running Plamo2 in transformers implementation requires to install # causal-conv1d package, which is not listed as a test dependency as it's # not compatible with pip-compile. 
@@ -35,9 +28,23 @@ "Zyphra/Zamba2-1.2B-instruct", "hmellor/tiny-random-BambaForCausalLM", "ibm-ai-platform/Bamba-9B-v1", + "nvidia/Nemotron-H-8B-Base-8K", + "ibm-granite/granite-4.0-tiny-preview" +] + +HF_UNSUPPORTED_MODELS = [ + # The HF transformers implementation of + # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test + # doesn't compare vLLM output with HF output. + # See https://github.com/huggingface/transformers/pull/35943 + "mistralai/Mamba-Codestral-7B-v0.1", # Note: I'm not seeing the same output from vLLM V0 vs. HF transformers # for Nemotron-H-8B; currently only compare vLLM V0 vs. vLLM V1 "nvidia/Nemotron-H-8B-Base-8K", + # Note: hf implementation is currently broken for this model, has been + # fixed on main pending release. + # see: https://github.com/huggingface/transformers/pull/39033 + "ibm-granite/granite-4.0-tiny-preview", ] V1_SUPPORTED_MODELS = [ @@ -45,12 +52,14 @@ "ibm-ai-platform/Bamba-9B-v1", "Zyphra/Zamba2-1.2B-instruct", "nvidia/Nemotron-H-8B-Base-8K", + "ibm-granite/granite-4.0-tiny-preview", ] ATTN_BLOCK_SIZES = { "ibm-ai-platform/Bamba-9B-v1": 528, "Zyphra/Zamba2-1.2B-instruct": 80, "nvidia/Nemotron-H-8B-Base-8K": 528, + "ibm-granite/granite-4.0-tiny-preview": 400, } # Avoid OOM @@ -70,10 +79,7 @@ def test_models( num_logprobs: int, ) -> None: with hf_runner(model) as hf_model: - if model not in [ - "mistralai/Mamba-Codestral-7B-v0.1", - "nvidia/Nemotron-H-8B-Base-8K" - ]: + if model not in HF_UNSUPPORTED_MODELS: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) else: diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 26b5b3ac1534..b3081ac9436e 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -9,6 +9,7 @@ from torch import nn from transformers import GraniteMoeHybridConfig +from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -35,7 +36,7 @@ from .granitemoe import GraniteMoeMoE from .granitemoeshared import GraniteMoeSharedMLP from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsQuant, SupportsV0Only) + SupportsQuant) from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -65,7 +66,9 @@ def __init__(self, head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.mixer", + chunk_size=config.mamba_chunk_size) self.block_sparse_moe = None if getattr(config, "num_local_experts", 0) > 0: @@ -354,10 +357,15 @@ def forward( ) -> torch.Tensor: attn_metadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) + + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -379,7 +387,9 @@ def forward( num_attn += 1 layer_mamba_cache_params = None - if isinstance(layer, GraniteMoeHybridMambaDecoderLayer): + if isinstance( + layer, + GraniteMoeHybridMambaDecoderLayer) and mamba_cache_params: 
layer_mamba_cache_params = mamba_cache_params.at_layer_idx( i - num_attn) @@ -471,8 +481,7 @@ def _load_expert(n, p, name, shard_id, expert_id): class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, - SupportsPP, IsHybrid, SupportsV0Only, - SupportsQuant): + SupportsPP, IsHybrid, SupportsQuant): packed_modules_mapping = {} embedding_modules = { "embed_tokens": "input_embeddings", @@ -535,14 +544,21 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.model_config.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_mamba_layers = \ + self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, + LayerBlockType.mamba + ) + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.model_config.dtype, + num_mamba_layers, *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) From e2c14ba8ed7f7ba09425ab58364e1d0dbfcb5a8d Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 27 Jun 2025 14:04:50 +0000 Subject: [PATCH 12/28] add support for Falcon H1 Signed-off-by: Thomas Parnell --- .../models/language/generation/test_hybrid.py | 5 +- vllm/model_executor/models/falcon_h1.py | 57 +++++++++++++------ 2 files changed, 44 insertions(+), 18 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index d4df38c879c3..c9be61a71023 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -29,7 +29,8 @@ "hmellor/tiny-random-BambaForCausalLM", "ibm-ai-platform/Bamba-9B-v1", "nvidia/Nemotron-H-8B-Base-8K", - "ibm-granite/granite-4.0-tiny-preview" + "ibm-granite/granite-4.0-tiny-preview", + "tiiuae/Falcon-H1-0.5B-Base", ] HF_UNSUPPORTED_MODELS = [ @@ -53,6 +54,7 @@ "Zyphra/Zamba2-1.2B-instruct", "nvidia/Nemotron-H-8B-Base-8K", "ibm-granite/granite-4.0-tiny-preview", + "tiiuae/Falcon-H1-0.5B-Base", ] ATTN_BLOCK_SIZES = { @@ -60,6 +62,7 @@ "Zyphra/Zamba2-1.2B-instruct": 80, "nvidia/Nemotron-H-8B-Base-8K": 528, "ibm-granite/granite-4.0-tiny-preview": 400, + "tiiuae/Falcon-H1-0.5B-Base": 800, } # Avoid OOM diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 28f257eabed0..a76e1f256e04 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -8,6 +8,7 @@ from torch import nn from transformers import FalconH1Config +from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -33,8 +34,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsV0Only) +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP from .utils import 
(PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -85,6 +85,7 @@ def __init__( config: FalconH1Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -107,6 +108,8 @@ def __init__( activation=config.hidden_act, quant_config=quant_config, use_rms_norm=config.mamba_rms_norm, + prefix=f"{prefix}.mixer", + chunk_size=config.mamba_chunk_size, ) # n_groups is overridden later by `MambaMixer2` self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state @@ -316,6 +319,7 @@ def __init__( prefix: str = "", ) -> None: super().__init__() + # Instantiate the attention branch self.self_attn = FalconH1AttentionDecoderLayer( config=config, @@ -323,11 +327,18 @@ def __init__( quant_config=quant_config, prefix=prefix, ) + + # In V1 all attention/ssm layers must have + # different index in prefix + ssm_layer_idx = config.num_hidden_layers + layer_idx + ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}" + # Instantiate the SSM branch self.mamba = FalconH1SSMDecoderLayer( config=config, cache_config=cache_config, quant_config=quant_config, + prefix=ssm_prefix, ) self.ssm_out_multiplier = config.ssm_out_multiplier self.ssm_in_multiplier = config.ssm_in_multiplier @@ -452,10 +463,16 @@ def forward( # proper continuous batching computation including # chunked prefill attn_metadata = get_forward_context().attn_metadata - mamba2_metadata = prepare_mamba2_metadata( - chunk_size=self.config.mamba_chunk_size, - attn_metadata=attn_metadata, - ) + + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None + if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds * self.embedding_multiplier @@ -468,7 +485,9 @@ def forward( for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i) + layer_mamba_cache_params = None + if mamba_cache_params: + layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i) hidden_states = layer( positions=positions, hidden_states=hidden_states, @@ -484,7 +503,7 @@ def forward( class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsV0Only): + IsHybrid): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -558,15 +577,19 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ): - if self.mamba_cache is None: - self.mamba_cache = MambaCacheManager( - self.vllm_config, - self.lm_head.weight.dtype - if hasattr(self.lm_head, 'weight') else torch.bfloat16, - self.config.num_hidden_layers, - *self._get_mamba_cache_shape(), - ) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + self.mamba_cache = MambaCacheManager( + self.vllm_config, + self.lm_head.weight.dtype if hasattr( + self.lm_head, 'weight') else torch.bfloat16, + self.config.num_hidden_layers, + *self._get_mamba_cache_shape(), + ) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + hidden_states = self.model( input_ids, positions, From c5a25ebf3ce749ccf446312c844dda59b52128fa Mon Sep 17 00:00:00 2001 From: Thomas 
Parnell Date: Mon, 30 Jun 2025 08:15:25 +0000 Subject: [PATCH 13/28] Fix overflow issue in mamba_ssm kernel Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/mamba_ssm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index ccfb278cdff6..3f67fc35afdf 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -108,7 +108,7 @@ def _selective_scan_update_kernel( # is the same as the batch id. if HAS_STATE_BATCH_INDICES: state_batch_indices_ptr += pid_b - state_batch_idx = tl.load(state_batch_indices_ptr) + state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) state_ptr += (state_batch_idx * stride_state_batch + pid_h * stride_state_head) else: From aaa6f0e0cf9f30edfb83f861ac9b3abacd22cad0 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 30 Jun 2025 11:49:25 +0000 Subject: [PATCH 14/28] Add check for transformers min version Co-authored-by: Stanislaw Wozniak Signed-off-by: Thomas Parnell --- tests/models/language/generation/test_hybrid.py | 6 ++++++ tests/models/registry.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 27080b3cb169..0c2f709e1871 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -3,6 +3,7 @@ import pytest +from tests.models.registry import HF_EXAMPLE_MODELS from tests.utils import multi_gpu_test from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams @@ -84,6 +85,11 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: + + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + with hf_runner(model) as hf_model: if model not in HF_UNSUPPORTED_MODELS: hf_outputs = hf_model.generate_greedy_logprobs_limit( diff --git a/tests/models/registry.py b/tests/models/registry.py index e56dd19bec67..96696b701e9d 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -165,7 +165,7 @@ def check_available_online( "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501 "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), - "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct", + "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base", min_transformers_version="4.53"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), From d187bfdf2c8f4b38b3eac150cdd50dc5361859eb Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 30 Jun 2025 13:17:50 +0000 Subject: [PATCH 15/28] Don't fail test if model is not in HF_EXAMPLE_MODELS Signed-off-by: Thomas Parnell --- tests/models/language/generation/test_hybrid.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 0c2f709e1871..0c2aac63a7f1 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -86,9 +86,12 @@ def test_models( num_logprobs: int, ) -> None: - 
model_info = HF_EXAMPLE_MODELS.find_hf_info(model) - model_info.check_available_online(on_fail="skip") - model_info.check_transformers_version(on_fail="skip") + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass with hf_runner(model) as hf_model: if model not in HF_UNSUPPORTED_MODELS: From fde28dc09541a9b4a251c9753d25bb8c0daa4ec6 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 30 Jun 2025 15:02:23 +0000 Subject: [PATCH 16/28] Fix test_batching in same way Signed-off-by: Thomas Parnell --- tests/models/language/generation/test_hybrid.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 0c2aac63a7f1..1123bb5c23ad 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -153,6 +153,14 @@ def test_batching( max_tokens: int, num_logprobs: int, ) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + for_loop_outputs = [] with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: for prompt in example_prompts: From 1777fd1adb730b3b538bb6838ba125eac99dcfbe Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 1 Jul 2025 21:44:07 +0200 Subject: [PATCH 17/28] Update tests/models/language/generation/test_hybrid.py typo Co-authored-by: Tyler Michael Smith Signed-off-by: Thomas Parnell --- tests/models/language/generation/test_hybrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 1123bb5c23ad..ecaae3ec1fc4 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -46,7 +46,7 @@ # Note: I'm not seeing the same output from vLLM V0 vs. HF transformers # for Nemotron-H-8B; currently only compare vLLM V0 vs. vLLM V1 "nvidia/Nemotron-H-8B-Base-8K", - # NOTE: Currently the test failes due to HF transformers issue fixed in: + # NOTE: Currently the test fails due to HF transformers issue fixed in: # https://github.com/huggingface/transformers/pull/39033 # We will enable vLLM test for Granite after next HF transformers release. 
"ibm-granite/granite-4.0-tiny-preview", From 58e66c95f5bc5cc6aac82f46a284891407498d7a Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 1 Jul 2025 22:10:39 +0200 Subject: [PATCH 18/28] Update vllm/model_executor/models/granitemoehybrid.py Co-authored-by: Chen Zhang Signed-off-by: Thomas Parnell --- vllm/model_executor/models/granitemoehybrid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 065f82a8582b..08c99f1ce91f 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -584,11 +584,11 @@ def forward(self, mamba_cache_params = None if not envs.VLLM_USE_V1: if self.mamba_cache is None: - num_mamba_layers = \ + num_mamba_layers = ( self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba - ) + )) self.mamba_cache = MambaCacheManager( self.vllm_config, self.model_config.dtype, num_mamba_layers, *self._get_mamba_cache_shape()) From c2da03e809b0d0036894b1ecb318756bb5a5580e Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 1 Jul 2025 20:15:41 +0000 Subject: [PATCH 19/28] page_size -> num_element_per_page Signed-off-by: Thomas Parnell --- vllm/v1/worker/gpu_model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 72ca0144b60f..1eea734d1908 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2444,14 +2444,14 @@ def _reshape_kv_cache_tensors( elif isinstance(kv_cache_spec, MambaSpec): raw_tensor = kv_cache_raw_tensors[layer_name] dtype = kv_cache_spec.dtype - page_size = kv_cache_spec.page_size_bytes // get_dtype_size( - dtype) + num_element_per_page = (kv_cache_spec.page_size_bytes // + get_dtype_size(dtype)) state_tensors = [] storage_offset = 0 for shape in kv_cache_spec.shapes: target_shape = (num_blocks, *shape) stride = torch.empty(target_shape).stride() - target_stride = (page_size, *stride[1:]) + target_stride = (num_element_per_page, *stride[1:]) tensor = torch.as_strided( raw_tensor.view(dtype), size=target_shape, From e0404c936c1cb2da6ed69f4657f485fc398aa303 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 1 Jul 2025 20:44:25 +0000 Subject: [PATCH 20/28] Clean up page size padding logic Signed-off-by: Thomas Parnell --- vllm/v1/kv_cache_interface.py | 9 +++--- vllm/v1/worker/gpu_model_runner.py | 49 ++++++++++++++++-------------- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 31e2ac23f5e4..43456a987def 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -11,7 +11,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import cdiv, get_dtype_size, round_up +from vllm.utils import cdiv, get_dtype_size logger = init_logger(__name__) @@ -159,7 +159,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: class MambaSpec(KVCacheSpec): shapes: tuple[tuple[int, ...], ...] 
dtype: torch.dtype - multiple_of: Optional[int] + page_size_padded: Optional[int] = None def __post_init__(self): self.num_elements = sum(prod(shape) for shape in self.shapes) @@ -171,8 +171,9 @@ def type_id(self) -> str: @property def page_size_bytes(self) -> int: page_size = self.num_elements * get_dtype_size(self.dtype) - if self.multiple_of is not None: - page_size = round_up(page_size, self.multiple_of) + if self.page_size_padded is not None: + assert self.page_size_padded >= page_size + return self.page_size_padded return page_size def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1eea734d1908..041af0b1225e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2575,15 +2575,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaMixer2) if len(mamba_layers) > 0: - if len(attn_layers) > 0: - # Mamba state must be padded to an integer number of - # 16th tokens worth of attention pages - attn_layer_name = next(iter(attn_layers)) - attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes - multiple_of = 16 * attn_page_size // block_size - else: - multiple_of = None - if self.vllm_config.speculative_config is not None: raise NotImplementedError( "Mamba with speculative decoding is not supported yet.") @@ -2594,6 +2585,32 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: raise NotImplementedError( "Prefix caching is not supported for Mamba yet.") max_model_len = self.vllm_config.model_config.max_model_len + + if len(attn_layers) > 0: + attn_layer_name = next(iter(attn_layers)) + attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes + mamba_layer_name = next(iter(mamba_layers)) + mamba_page_size = MambaSpec( + shapes=mamba_layers[mamba_layer_name].get_state_shape(), + dtype=self.kv_cache_dtype, + block_size=max_model_len).page_size_bytes + if attn_page_size < mamba_page_size: + # attention page size (for 16 tokens) + attn_page_size_16 = 16 * attn_page_size // block_size + # some attention backends (e.g. FA) only support setting + # block size to multiple of 16, so let's suggest a value + # that would work (note: FA is currently not compatible + # with mamba layers, use FlashInfer instead). + suggest_attn_block_size = 16 * cdiv( + mamba_page_size, attn_page_size_16) + raise ValueError( + "Attention block size should be increased to at least " + f"{suggest_attn_block_size} in order to match " + "the mamba page size") + page_size_padded = attn_page_size + else: + page_size_padded = None + # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. 
for layer_name, mamba_module in mamba_layers.items(): @@ -2601,18 +2618,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: shapes=mamba_module.get_state_shape(), dtype=self.kv_cache_dtype, block_size=max_model_len, - multiple_of=multiple_of) - - if len(attn_layers) > 0: - mamba_layer_name = next(iter(mamba_layers)) - mamba_page_size = kv_cache_spec[ - mamba_layer_name].page_size_bytes - if attn_page_size < mamba_page_size: - required_attn_block_size = cdiv(mamba_page_size, - multiple_of) * 16 - raise ValueError( - "Attention block size must be increased to " - f"{required_attn_block_size} in order to match " - "the mamba page size") + page_size_padded=page_size_padded) return kv_cache_spec From c74698d8c7aaa759d6ed8ddeb4ce087da1f767cd Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 1 Jul 2025 20:52:34 +0000 Subject: [PATCH 21/28] gpu_model_runner.py: add TODO about batch reordering Signed-off-by: Thomas Parnell --- vllm/v1/worker/gpu_model_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 041af0b1225e..d33b537f2b40 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -336,6 +336,8 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: # the same order of requests. We ensure this by only allowing the first # group to reorder the batch and asserting that all other groups do not # reorder the batch. + # TODO(tdoublep): make this more flexible so that any group can + # re-order the batch (not only the first). for i in range(1, len(self.kv_cache_config.kv_cache_groups)): assert not self.attn_metadata_builders[i].reorder_batch( self.input_batch, scheduler_output) From b72b729eb1d2bc1720bc5955f7d2001e8ad0f235 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 1 Jul 2025 21:01:37 +0000 Subject: [PATCH 22/28] Fix linting issue Signed-off-by: Thomas Parnell --- vllm/model_executor/models/granitemoehybrid.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 08c99f1ce91f..676ef24fc4da 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -587,8 +587,7 @@ def forward(self, num_mamba_layers = ( self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, - LayerBlockType.mamba - )) + LayerBlockType.mamba)) self.mamba_cache = MambaCacheManager( self.vllm_config, self.model_config.dtype, num_mamba_layers, *self._get_mamba_cache_shape()) From 105737cee695fde5b46fd304f7384cf5db026045 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 1 Jul 2025 21:32:27 +0000 Subject: [PATCH 23/28] Validate memory layout for hybrid models against attention backends Signed-off-by: Thomas Parnell --- vllm/v1/worker/gpu_model_runner.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d33b537f2b40..d97bd3bfa400 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2407,6 +2407,7 @@ def _reshape_kv_cache_tensors( corresponding memory buffer for KV cache. 
""" kv_caches: dict[str, torch.Tensor] = {} + needs_validation = False for i, kv_cache_group_spec in enumerate( kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group_spec.kv_cache_spec @@ -2444,6 +2445,7 @@ def _reshape_kv_cache_tensors( layer_name].view(dtype).view(kv_cache_shape).permute( *inv_order) elif isinstance(kv_cache_spec, MambaSpec): + needs_validation = True raw_tensor = kv_cache_raw_tensors[layer_name] dtype = kv_cache_spec.dtype num_element_per_page = (kv_cache_spec.page_size_bytes // @@ -2467,6 +2469,28 @@ def _reshape_kv_cache_tensors( else: raise NotImplementedError + # Validate layout for hybrid models + if needs_validation: + for i, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups): + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + for layer_name in kv_cache_group_spec.layer_names: + raw_tensor = kv_cache_raw_tensors[layer_name] + num_blocks = (raw_tensor.numel() // + kv_cache_spec.page_size_bytes) + if isinstance(kv_cache_spec, AttentionSpec): + kv_cache_shape = self.attn_backends[ + i].get_kv_cache_shape(num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size) + if kv_cache_shape[0] != num_blocks or kv_cache_shape[ + 1] != 2: + raise ValueError( + "Hybrid models in V1 require an attention " + "backend with kv_cache_shape=" + "(num_blocks, 2, ...). Please try setting " + "VLLM_ATTENTION_BACKEND=FLASHINFER") return kv_caches def initialize_kv_cache_tensors( From d8ff3b900f030f3d87cce7846f32039881a27d82 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 1 Jul 2025 22:13:22 +0000 Subject: [PATCH 24/28] Adjust comment Signed-off-by: Thomas Parnell --- vllm/model_executor/models/zamba2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index c273f70777d2..54c80cfa5922 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -704,7 +704,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # Initialize layers according to block type configuration layers = [] for layer_idx, layer_type in enumerate(config.layers_block_type): - # tpa: avoid layers getting same index + # tdoublep: avoid layers getting same index # somewhat hacky but correct (I think) prefix = str(len(layer2block_map) + layer_idx) if layer_type == "hybrid": From ea8cf32b885eac956cd0610fc03edbb12fec5b0d Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 4 Jul 2025 12:45:21 +0000 Subject: [PATCH 25/28] Move memory layout check into separate function Signed-off-by: Thomas Parnell --- vllm/v1/worker/gpu_model_runner.py | 62 ++++++++++++++++++------------ 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 268db5d804fc..5bec34667a46 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2451,7 +2451,7 @@ def _reshape_kv_cache_tensors( corresponding memory buffer for KV cache. 
""" kv_caches: dict[str, torch.Tensor] = {} - needs_validation = False + has_attn, has_mamba = False, False for i, kv_cache_group_spec in enumerate( kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group_spec.kv_cache_spec @@ -2461,6 +2461,7 @@ def _reshape_kv_cache_tensors( num_blocks = (raw_tensor.numel() // kv_cache_spec.page_size_bytes) if isinstance(kv_cache_spec, AttentionSpec): + has_attn = True kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) @@ -2489,7 +2490,7 @@ def _reshape_kv_cache_tensors( layer_name].view(dtype).view(kv_cache_shape).permute( *inv_order) elif isinstance(kv_cache_spec, MambaSpec): - needs_validation = True + has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] dtype = kv_cache_spec.dtype num_element_per_page = (kv_cache_spec.page_size_bytes // @@ -2513,30 +2514,43 @@ def _reshape_kv_cache_tensors( else: raise NotImplementedError - # Validate layout for hybrid models - if needs_validation: - for i, kv_cache_group_spec in enumerate( - kv_cache_config.kv_cache_groups): - kv_cache_spec = kv_cache_group_spec.kv_cache_spec - for layer_name in kv_cache_group_spec.layer_names: - raw_tensor = kv_cache_raw_tensors[layer_name] - num_blocks = (raw_tensor.numel() // - kv_cache_spec.page_size_bytes) - if isinstance(kv_cache_spec, AttentionSpec): - kv_cache_shape = self.attn_backends[ - i].get_kv_cache_shape(num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size) - if kv_cache_shape[0] != num_blocks or kv_cache_shape[ - 1] != 2: - raise ValueError( - "Hybrid models in V1 require an attention " - "backend with kv_cache_shape=" - "(num_blocks, 2, ...). Please try setting " - "VLLM_ATTENTION_BACKEND=FLASHINFER") + if has_attn and has_mamba: + self._verify_hybrid_attention_mamba_layout(kv_cache_config, + kv_cache_raw_tensors) + return kv_caches + def _verify_hybrid_attention_mamba_layout( + self, kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None: + """ + Verify that the KV cache memory layout is compatible for + models with both attention and mamba KV cache groups. + + Args: + kv_cache_config: The KV cache config + kv_cache_raw_tensors: The KV cache buffer of each layer. + """ + + for i, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups): + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + for layer_name in kv_cache_group_spec.layer_names: + raw_tensor = kv_cache_raw_tensors[layer_name] + num_blocks = (raw_tensor.numel() // + kv_cache_spec.page_size_bytes) + if isinstance(kv_cache_spec, AttentionSpec): + kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + if kv_cache_shape[0] != num_blocks or kv_cache_shape[ + 1] != 2: + raise ValueError( + "Hybrid models in V1 require an attention " + "backend with kv_cache_shape=" + "(num_blocks, 2, ...). 
Please try setting " + "VLLM_ATTENTION_BACKEND=FLASHINFER") + def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: """ From b38d3fba402df336975af3b75440cd9d67771c19 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 4 Jul 2025 13:09:56 +0000 Subject: [PATCH 26/28] Move logic to pad mamba page size into separate function Signed-off-by: Thomas Parnell --- vllm/v1/worker/gpu_model_runner.py | 78 +++++++++++++++++++++--------- 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5bec34667a46..9c3fb01e04e8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2670,30 +2670,9 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: "Prefix caching is not supported for Mamba yet.") max_model_len = self.vllm_config.model_config.max_model_len - if len(attn_layers) > 0: - attn_layer_name = next(iter(attn_layers)) - attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes - mamba_layer_name = next(iter(mamba_layers)) - mamba_page_size = MambaSpec( - shapes=mamba_layers[mamba_layer_name].get_state_shape(), - dtype=self.kv_cache_dtype, - block_size=max_model_len).page_size_bytes - if attn_page_size < mamba_page_size: - # attention page size (for 16 tokens) - attn_page_size_16 = 16 * attn_page_size // block_size - # some attention backends (e.g. FA) only support setting - # block size to multiple of 16, so let's suggest a value - # that would work (note: FA is currently not compatible - # with mamba layers, use FlashInfer instead). - suggest_attn_block_size = 16 * cdiv( - mamba_page_size, attn_page_size_16) - raise ValueError( - "Attention block size should be increased to at least " - f"{suggest_attn_block_size} in order to match " - "the mamba page size") - page_size_padded = attn_page_size - else: - page_size_padded = None + page_size_padded = self._maybe_pad_mamba_page_size( + attn_layers, mamba_layers, kv_cache_spec, max_model_len, + block_size) # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. @@ -2705,3 +2684,54 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: page_size_padded=page_size_padded) return kv_cache_spec + + def _maybe_pad_mamba_page_size( + self, + attn_layers: dict[str, Attention], + mamba_layers: dict[str, MambaMixer2], + kv_cache_spec: dict[str, KVCacheSpec], + max_model_len: int, + block_size: int, + ) -> Optional[int]: + """ + Ensure that page size of attention KV cache groups is greater than or + equal to the mamba KV cache groups. If not, we suggest to the user + how to set the attention block size to ensure that it is. + + If the attention page size is strictly greater than the mamba page size, + we pad the mamba page size to make them equal. + + Args: + attn_layers: Attention layers + mamba_layers: Mamba layers + kv_cache_spec: KV cache spec (populated with attention layers) + + Returns: + Optional[int]: Mamba page size with padding (None if no padding). 
+ """ + + if len(attn_layers) == 0: + return None + + attn_layer_name = next(iter(attn_layers)) + attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes + mamba_layer_name = next(iter(mamba_layers)) + mamba_page_size = MambaSpec( + shapes=mamba_layers[mamba_layer_name].get_state_shape(), + dtype=self.kv_cache_dtype, + block_size=max_model_len).page_size_bytes + if attn_page_size < mamba_page_size: + # attention page size (for 16 tokens) + attn_page_size_16 = 16 * attn_page_size // block_size + # some attention backends (e.g. FA) only support setting + # block size to multiple of 16, so let's suggest a value + # that would work (note: FA is currently not compatible + # with mamba layers, use FlashInfer instead). + suggest_attn_block_size = 16 * cdiv(mamba_page_size, + attn_page_size_16) + raise ValueError( + "Attention block size should be increased to at least " + f"{suggest_attn_block_size} in order to match " + "the mamba page size") + + return attn_page_size From e6b001504cc7933013a88ad3e00fd5756cc070fd Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 4 Jul 2025 13:11:48 +0000 Subject: [PATCH 27/28] Add extra todo Signed-off-by: Thomas Parnell --- vllm/v1/worker/gpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9c3fb01e04e8..57d0c7b50ff5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -336,6 +336,7 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: # reorder the batch. # TODO(tdoublep): make this more flexible so that any group can # re-order the batch (not only the first). + # TODO(tdoublep): verify this during engine init instead of at runtime for i in range(1, len(self.kv_cache_config.kv_cache_groups)): batch_reordered = self.attn_metadata_builders[i].reorder_batch( self.input_batch, scheduler_output) From 14fd0065a6753b1698e6ac9afa58b7ee9be9fd7b Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 4 Jul 2025 15:09:45 +0000 Subject: [PATCH 28/28] test_oracle.py: hybrid models now supported Signed-off-by: Thomas Parnell --- tests/v1/test_oracle.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index d640d7dc49d1..7a7ba346a719 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -13,7 +13,6 @@ "openai/whisper-large-v3", # transcription "facebook/bart-large-cnn", # encoder decoder "state-spaces/mamba-130m-hf", # mamba1 - "hmellor/tiny-random-BambaForCausalLM", # hybrid "BAAI/bge-m3", # embedding ]
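
Note on the Mamba KV-cache memory layout: the `_reshape_kv_cache_tensors` changes carve each layer's state tensors out of one flat buffer with `torch.as_strided`, so that every logical block occupies one (possibly padded) page and the conv/SSM states of a block sit next to each other inside that page. The following is a minimal standalone sketch of that layout only; all shapes, the dtype and the padding value are invented illustration values, not the real ones (in vLLM they come from `MambaSpec.shapes` and `page_size_bytes`).

# Standalone sketch of the strided view layout; shapes/dtype/padding are
# invented for illustration.
import math
import torch

num_blocks = 4
conv_shape = (8, 16)          # hypothetical conv-state shape per block
ssm_shape = (8, 32, 32)       # hypothetical SSM-state shape per block
dtype = torch.float16

elems_per_block = math.prod(conv_shape) + math.prod(ssm_shape)
# Pretend the attention layers force a slightly larger (padded) page.
padded_elems_per_page = elems_per_block + 64

raw = torch.zeros(num_blocks * padded_elems_per_page, dtype=dtype)

state_tensors = []
storage_offset = 0
for shape in (conv_shape, ssm_shape):
    target_shape = (num_blocks, *shape)
    # torch.empty here is only used to get the contiguous strides of this shape.
    contiguous_stride = torch.empty(target_shape).stride()
    # Jump one whole padded page between blocks, so block i of every state
    # tensor lives inside page i of the raw buffer.
    target_stride = (padded_elems_per_page, *contiguous_stride[1:])
    state_tensors.append(
        torch.as_strided(raw, size=target_shape, stride=target_stride,
                         storage_offset=storage_offset))
    # The next state starts right after this one inside the same page.
    storage_offset += contiguous_stride[0]

conv_state, ssm_state = state_tensors
# Consecutive blocks of the same state are exactly one padded page apart.
assert (conv_state[1].data_ptr() - conv_state[0].data_ptr()
        == padded_elems_per_page * raw.element_size())

The padding (`page_size_padded`) is what lets the attention and Mamba cache groups share a single uniform page size even though the Mamba state needs fewer bytes per block.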
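
The block-size suggestion raised by `_maybe_pad_mamba_page_size` boils down to the arithmetic sketched below; the byte counts are invented illustration values, with the real ones coming from the attention and Mamba `KVCacheSpec`s. This is presumably also why the test's `ATTN_BLOCK_SIZES` pins per-model values such as 528 or 800: each is a multiple of 16 large enough that the attention page is at least as big as that model's Mamba page.

# Standalone sketch of the suggestion arithmetic; byte counts are invented.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)          # ceiling division (what vllm.utils.cdiv computes)

block_size = 16                      # tokens per attention block
attn_page_size = 272 * 1024          # hypothetical bytes per attention block
mamba_page_size = 8 * 1024 * 1024    # hypothetical bytes per Mamba page

if attn_page_size < mamba_page_size:
    # attention bytes needed for 16 tokens
    attn_page_size_16 = 16 * attn_page_size // block_size
    # smallest multiple of 16 whose attention page covers the Mamba page
    suggest_attn_block_size = 16 * cdiv(mamba_page_size, attn_page_size_16)
    print(f"raise the attention block size to at least {suggest_attn_block_size}")
else:
    # otherwise the Mamba page is padded up to the attention page size
    page_size_padded = attn_page_size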