diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py new file mode 100644 index 00000000000..b4e0101a0d4 --- /dev/null +++ b/tests/v1/attention/test_attention_backends.py @@ -0,0 +1,466 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for v1 attention backends without GPUModelRunner dependency.""" + +import pytest +import torch + +from tests.v1.attention.utils import (BatchSpec, _Backend, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + get_attention_backend) +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import FullAttentionSpec + +BACKENDS_TO_TEST = [ + _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1, + _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1 +] + +# Remove flashinfer from the list if it's not available +try: + import flashinfer # noqa: F401 +except ImportError: + BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_VLLM_V1) + + +def _convert_dtype_to_torch(dtype): + """Convert ModelDType to torch.dtype.""" + if isinstance(dtype, str): + if dtype == "auto": + return torch.float16 # Default dtype for testing + elif dtype in STR_DTYPE_TO_TORCH_DTYPE: + return STR_DTYPE_TO_TORCH_DTYPE[dtype] + else: + raise ValueError(f"Unknown dtype: {dtype}") + elif isinstance(dtype, torch.dtype): + return dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + +# Define common batch configurations +BATCH_SPECS = { + "small_decode": + BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), + "small_prefill": + BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), + "mixed_small": + BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), + "medium_decode": + BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], + query_lens=[1, 1, 1, 1, 1, 1, 1, 1]), + "medium_prefill": + BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]), + "mixed_medium": + BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048], + query_lens=[1, 1, 1, 7, 7, 7]), + "large_decode": + BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), + "large_prefill": + BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "single_decode": + BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": + BatchSpec(seq_lens=[1024], query_lens=[64]), +} + + +def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, + device: torch.device, + num_blocks: int = 100) -> torch.Tensor: + """Create a dummy KV cache tensor for testing.""" + kv_cache = torch.randn( + 2, # K and V + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), + device=device, + ) + return kv_cache + + +def create_and_prepopulate_kv_cache( + k_contexts: list[torch.Tensor], + v_contexts: list[torch.Tensor], + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True) -> torch.Tensor: + """Create and prepopulate a KV cache with context data. 
+
+    Args:
+        k_contexts: List of key context tensors for each sequence
+        v_contexts: List of value context tensors for each sequence
+        block_size: Size of each block
+        num_kv_heads: Number of KV heads
+        head_size: Size of each head
+        dtype: Data type for the cache
+        device: Device to create the cache on
+        num_blocks: Total number of blocks in the cache
+        common_attn_metadata: Metadata whose block table and slot mapping
+            are populated to match the cache layout
+        randomize_blocks: Whether to randomly permute blocks
+                          or use sequential order
+
+    Returns:
+        The prepopulated KV cache tensor
+    """
+    batch_size = len(k_contexts)
+    seq_lens = common_attn_metadata.seq_lens_cpu
+    query_lens = common_attn_metadata.query_start_loc_cpu[
+        1:] - common_attn_metadata.query_start_loc_cpu[:-1]
+    context_lens = common_attn_metadata.num_computed_tokens_cpu
+    block_table = common_attn_metadata.block_table_tensor
+    slot_mapping = common_attn_metadata.slot_mapping
+
+    # Create KV cache
+    kv_cache = torch.empty(2,
+                           num_blocks,
+                           block_size,
+                           num_kv_heads,
+                           head_size,
+                           dtype=dtype,
+                           device=device)
+    kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size)
+
+    # Populate the cache with the context tokens
+    # Start from block_id=1 since block_id=0 is considered the null block
+    start_block_idx = 1
+    for i in range(batch_size):
+        k_context, v_context = k_contexts[i], v_contexts[i]
+        start = start_block_idx * block_size
+        end = start + k_context.shape[0]
+        kv_cache_flat[0, start:end, ...] = k_context
+        kv_cache_flat[1, start:end, ...] = v_context
+
+        # Stay block aligned and allocate enough blocks for the new tokens
+        start_block_idx += cdiv(int(seq_lens[i]), block_size)
+
+    blocks_end = start_block_idx
+
+    # Permute the context blocks (excluding block 0 which is null)
+    if randomize_blocks:
+        perm = torch.randperm(
+            blocks_end - 1) + 1  # Random permutation starting from block 1
+    else:
+        perm = torch.arange(
+            1, blocks_end)  # Sequential order starting from block 1
+
+    inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
+    inv_perm[1:] = torch.argsort(
+        perm) + 1  # Add 1 to account for starting from block 1
+    kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...]
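+
+    # Note: after the shuffle above, physical block j holds the data that was
+    # originally written to block perm[j - 1], so inv_perm maps each original
+    # block index to its new physical location; the block table below is
+    # filled with these physical indices so lookups land on the moved data.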
+ + # Construct the right block table + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + num_blocks_for_seq = cdiv(int(seq_lens[i]), block_size) + start = start_block_idx + end = start + num_blocks_for_seq + block_table[i, :num_blocks_for_seq] = inv_perm[start:end] + start_block_idx += num_blocks_for_seq + + # Create a realistic slot mapping that corresponds to the block table + for i in range(batch_size): + token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i]) + block_indices = token_offsets // block_size + token_inter_block_offsets = token_offsets % block_size + start = common_attn_metadata.query_start_loc_cpu[i] + end = common_attn_metadata.query_start_loc_cpu[i + 1] + slot_mapping[start:end] = block_table[ + i, + block_indices] * block_size + token_inter_block_offsets.to(device) + + return kv_cache + + +class MockAttentionLayer: + """A mock attention layer for testing.""" + + def __init__(self, device: torch.device): + self._q_scale = torch.tensor(1.0, device=device) + self._k_scale = torch.tensor(1.0, device=device) + self._v_scale = torch.tensor(1.0, device=device) + # Add float versions for flashinfer + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + + +def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, + vllm_config, device: torch.device, + common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor) -> torch.Tensor: + """Run attention computation using the specified backend's AttentionImpl.""" + + builder_cls, impl_cls = get_attention_backend(backend) + + # Mock flashinfer's get_per_layer_parameters if needed + if backend == _Backend.FLASHINFER_VLLM_V1: + import unittest.mock + + from vllm.v1.attention.backends.flashinfer import PerLayerParameters + + def mock_get_per_layer_parameters(vllm_config): + # Return mock parameters for a single layer + head_size = vllm_config.model_config.get_head_size() + return { + "mock_layer": + PerLayerParameters( + window_left=-1, # No sliding window + logits_soft_cap=0.0, # No soft cap + sm_scale=1.0 / (head_size**0.5) # Standard scale + ) + } + + with unittest.mock.patch( + 'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters', + mock_get_per_layer_parameters): + builder = builder_cls(kv_cache_spec, vllm_config, device) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + else: + # Build metadata + builder = builder_cls(kv_cache_spec, vllm_config, device) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + + # Instantiate implementation + num_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + scale = 1.0 / (head_size**0.5) + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + ) + + # Create mock layer and output buffer + mock_layer = MockAttentionLayer(device) + output = torch.empty_like(query) + + # Run forward pass + # NOTE: The query, key, and value are already shaped correctly + # in the calling test function. 
+ output = impl.forward(mock_layer, + query, + key, + value, + kv_cache, + attn_metadata, + output=output) + + return output + + +@pytest.mark.parametrize("batch_spec_name", [ + "small_decode", "small_prefill", "mixed_small", "medium_decode", + "medium_prefill", "mixed_medium" +]) +@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) +def test_backend_correctness(batch_spec_name: str, model: str): + """ + Test that all backends produce similar outputs to a reference implementation + using torch.nn.functional.scaled_dot_product_attention. + + This test works by: + 1. Generating a batch of sequences with specified context and query lengths. + 2. Computing a ground-truth attention output using torch.sdpa on + contiguous Q, K, and V tensors. + 3. Simulating vLLM's paged KV cache: It takes the context portion of the + K/V tensors and manually places them into a paged buffer according to + the test's (randomly generated) block table. + 4. Running each vLLM attention backend with the new queries and the + simulated paged KV cache. + 5. Comparing the vLLM backend's output to the ground-truth SDPA output. + """ + batch_spec = BATCH_SPECS[batch_spec_name] + vllm_config = create_vllm_config(model_name=model) + device = torch.device("cuda:0") + + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + # 1. Setup + batch_size = batch_spec.batch_size + seq_lens = batch_spec.seq_lens + query_lens = batch_spec.query_lens + num_q_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + block_size = vllm_config.cache_config.block_size + scale = 1.0 / (head_size**0.5) + + # 2. 
Generate data and compute SDPA reference output + all_q_vllm, all_k_vllm, all_v_vllm = [], [], [] + all_sdpa_outputs = [] + k_contexts, v_contexts = [], [] + + for i in range(batch_size): + s_len = seq_lens[i] + q_len = query_lens[i] + context_len = s_len - q_len + + # Generate Q, K, V for the whole sequence to be used in SDPA + q = torch.randn(q_len, + num_q_heads, + head_size, + dtype=dtype, + device=device) + k_full = torch.randn(s_len, + num_kv_heads, + head_size, + dtype=dtype, + device=device) + v_full = torch.randn(s_len, + num_kv_heads, + head_size, + dtype=dtype, + device=device) + + # SDPA expects (N, H, L, D), so unsqueeze batch and permute + q_sdpa_in = q.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + + if num_q_heads != num_kv_heads: + assert num_q_heads % num_kv_heads == 0, ( + f"num_q_heads ({num_q_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads})") + repeats = num_q_heads // num_kv_heads + k_sdpa_in = k_sdpa_in.repeat_interleave(repeats, dim=1) + v_sdpa_in = v_sdpa_in.repeat_interleave(repeats, dim=1) + + # Create causal mask: query token i attends to positions 0 to + # (context_len + i) + kv_len = s_len + offset = context_len + attn_mask = torch.full((q_len, kv_len), + float('-inf'), + device=device, + dtype=dtype) + for i in range(q_len): + attn_mask[i, :offset + i + 1] = 0.0 + + sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, + k_sdpa_in, + v_sdpa_in, + attn_mask=attn_mask, + scale=scale, + enable_gqa=True) + # Convert back to (L, H, D) + all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0)) + + # Inputs for vLLM backends are just the new tokens + all_q_vllm.append(q) + all_k_vllm.append(k_full[context_len:]) + all_v_vllm.append(v_full[context_len:]) + + # Contextual K/V data used to populate the paged cache + k_contexts.append(k_full[:context_len]) + v_contexts.append(v_full[:context_len]) + + query_vllm = torch.cat(all_q_vllm, dim=0) + key_vllm = torch.cat(all_k_vllm, dim=0) + value_vllm = torch.cat(all_v_vllm, dim=0) + sdpa_output = torch.cat(all_sdpa_outputs, dim=0) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device) + + # 3. Simulate Paged KV Cache and a realistic slot_mapping + kv_cache = create_and_prepopulate_kv_cache( + k_contexts=k_contexts, + v_contexts=v_contexts, + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + device=device, + num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000, + common_attn_metadata=common_attn_metadata, + randomize_blocks=True) + + # 4. 
Run vLLM backends and compare
+    # Note: flex_attention has known Triton kernel compatibility issues
+    # with test infrastructures
+    for backend_name in BACKENDS_TO_TEST:
+        # FlashAttention + FlexAttention:
+        #   [2, num_blocks, block_size, num_kv_heads, head_size]
+        # FlashInfer:
+        #   [num_blocks, 2, block_size, num_kv_heads, head_size]
+        # Select the appropriate KV cache format for each backend
+        kv_cache_for_backend = kv_cache
+        if backend_name == _Backend.FLASHINFER_VLLM_V1:
+            kv_cache_for_backend = kv_cache.transpose(0, 1)
+
+        backend_output = run_attention_backend(backend_name, kv_cache_spec,
+                                               vllm_config, device,
+                                               common_attn_metadata,
+                                               query_vllm, key_vllm,
+                                               value_vllm,
+                                               kv_cache_for_backend)
+
+        # Check shape and dtype consistency
+        assert backend_output.shape == sdpa_output.shape, (
+            f"[{backend_name}] shape {backend_output.shape} != "
+            f"SDPA shape {sdpa_output.shape}")
+        assert backend_output.dtype == sdpa_output.dtype, (
+            f"[{backend_name}] dtype {backend_output.dtype} != "
+            f"SDPA dtype {sdpa_output.dtype}")
+
+        assert torch.isfinite(backend_output).all(), (
+            f"[{backend_name}] produced non-finite values")
+
+        # Check numerical similarity
+        rtol = 1e-2
+        atol = 5e-3
+
+        if backend_name == _Backend.FLEX_ATTENTION:
+            # TODO: figure out why flex_attention has such large numerical
+            # differences for medium_decode, medium_prefill, mixed_medium
+            atol = 5e-1
+
+        max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
+        max_rel_diff = torch.max(
+            torch.abs(backend_output - sdpa_output) /
+            torch.abs(sdpa_output)).item()
+        all_close = torch.allclose(backend_output,
+                                   sdpa_output,
+                                   rtol=rtol,
+                                   atol=atol)
+
+        if not all_close:
+            print(f"[{backend_name}] output differs from SDPA baseline. "
+                  f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})")
+            print(f"[{backend_name}] output: {backend_output}")
+            print(f"[{backend_name}] SDPA baseline: {sdpa_output}")
+
+        assert all_close, (
+            f"[{backend_name}] output differs from SDPA baseline. 
" + f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py new file mode 100644 index 00000000000..30cfbdda5d8 --- /dev/null +++ b/tests/v1/attention/utils.py @@ -0,0 +1,229 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utility functions for attention-related v1 tests.""" + +from dataclasses import dataclass +from typing import Union + +import pytest +import torch + +from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, + LoadConfig, ModelConfig, ModelDType, ParallelConfig, + SchedulerConfig, VllmConfig) +from vllm.platforms import _Backend +from vllm.utils import resolve_obj_by_qualname +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import FullAttentionSpec + + +@dataclass +class BatchSpec: + """Specification for a batch configuration (workload shape only).""" + seq_lens: list[int] + query_lens: list[int] + + name: str = "unnamed" + + @property + def batch_size(self): + return len(self.seq_lens) + + def __post_init__(self): + assert len(self.seq_lens) == len(self.query_lens) + + def compute_num_tokens(self): + return sum(self.query_lens) + + +def create_common_attn_metadata( + batch_spec: BatchSpec, + block_size: int, + device: torch.device, + max_block_idx: int = 1000) -> CommonAttentionMetadata: + """Create CommonAttentionMetadata from a BatchSpec and ModelParams.""" + # Create query start locations + query_start_loc = torch.zeros(batch_spec.batch_size + 1, + dtype=torch.int32, + device=device) + query_start_loc[1:] = torch.tensor(batch_spec.query_lens, + dtype=torch.int32, + device=device).cumsum(0) + query_start_loc_cpu = query_start_loc.cpu() + num_tokens = batch_spec.compute_num_tokens() + + # Create sequence lengths + seq_lens = torch.tensor(batch_spec.seq_lens, + dtype=torch.int32, + device=device) + seq_lens_cpu = seq_lens.cpu() + + # Create computed tokens (context length for each sequence) + context_lens = [ + batch_spec.seq_lens[i] - batch_spec.query_lens[i] + for i in range(batch_spec.batch_size) + ] + num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32) + + # Create block table (random for testing) + max_blocks = max(batch_spec.seq_lens) // block_size + 1 + block_table_tensor = torch.randint(0, + max_block_idx, + (batch_spec.batch_size, max_blocks), + dtype=torch.int32, + device=device) + + # Create slot mapping + slot_mapping = torch.randint(0, + max_block_idx, (num_tokens, ), + dtype=torch.int64, + device=device) + + # Calculate max query length + max_query_len = max(batch_spec.query_lens) + + return CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=num_computed_tokens_cpu, + num_reqs=batch_spec.batch_size, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + block_table_tensor=block_table_tensor, + slot_mapping=slot_mapping, + ) + + +def get_attention_backend(backend_name: _Backend): + """Set up attention backend classes for testing. + + Args: + backend_name: Name of the backend ("flash_attn", "flashinfer", etc.) 
+ vllm_config: VllmConfig instance + + Returns: + Tuple of (backend_builder_class, backend_impl_class) + """ + backend_map = { + _Backend.FLASH_ATTN_VLLM_V1: + "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", + _Backend.FLASHINFER_VLLM_V1: + "vllm.v1.attention.backends.flashinfer.FlashInferBackend", + _Backend.FLEX_ATTENTION: + "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", + _Backend.TRITON_ATTN_VLLM_V1: + "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", + } + + if backend_name not in backend_map: + raise ValueError(f"Unknown backend: {backend_name}") + + backend_class_name = backend_map[backend_name] + + try: + backend_class = resolve_obj_by_qualname(backend_class_name) + return backend_class.get_builder_cls(), backend_class.get_impl_cls() + except ImportError as e: + pytest.skip(f"{backend_name} not available: {e}") + + +def create_standard_kv_cache_spec( + vllm_config: VllmConfig) -> FullAttentionSpec: + """Create a FullAttentionSpec from ModelParams only.""" + return FullAttentionSpec( + block_size=vllm_config.cache_config.block_size, + num_kv_heads=vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config), + head_size=vllm_config.model_config.get_head_size(), + dtype=vllm_config.model_config.dtype, + use_mla=vllm_config.model_config.use_mla, + sliding_window=vllm_config.model_config.get_sliding_window(), + ) + + +def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", + tensor_parallel_size: int = 1, + max_model_len: int = 1024, + dtype: Union[ModelDType, torch.dtype] = "auto", + block_size: int = 16, + max_num_seqs: int = 256, + max_num_batched_tokens: int = 8192, + add_mock_model_methods: bool = True) -> VllmConfig: + """Create a VllmConfig for testing with reasonable defaults.""" + + model_config = ModelConfig( + model=model_name, + tokenizer=model_name, + trust_remote_code=False, + dtype=dtype, + seed=0, + max_model_len=max_model_len, + ) + + cache_config = CacheConfig( + block_size=block_size, + cache_dtype="auto", + swap_space=0, + ) + # Set cache blocks for testing + # (these may be set during initialization normally) + cache_config.num_gpu_blocks = 1000 + cache_config.num_cpu_blocks = 0 + + parallel_config = ParallelConfig( + tensor_parallel_size=tensor_parallel_size, ) + + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + ) + + device_config = DeviceConfig() + load_config = LoadConfig() + compilation_config = CompilationConfig() + + if add_mock_model_methods: + # Add mock methods to satisfy backends that need them + # This is a workaround because tests don't build full, real models, + # but some backends expect to query the model for layer-specific + # parameters + import types + model_config.get_num_layers = types.MethodType(lambda self: 1, + model_config) + model_config.get_sliding_window_for_layer = types.MethodType( + lambda self, i: None, model_config) + model_config.get_logits_soft_cap_for_layer = types.MethodType( + lambda self, i: 0.0, model_config) + model_config.get_sm_scale_for_layer = types.MethodType( + lambda self, i: 1.0 / model_config.get_head_size()**0.5, + model_config) + + return VllmConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config, + compilation_config=compilation_config, + ) + + +def create_dummy_kv_cache(block_size: int, + num_kv_heads: int, + head_size: int, + 
dtype: torch.dtype, + device: torch.device, + num_blocks: int = 100) -> torch.Tensor: + """Create a dummy KV cache tensor for testing.""" + kv_cache = torch.randn( + num_blocks, + 2, # K and V + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device=device) + return kv_cache diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 5efab2c1440..5c74a286c4a 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -6,6 +6,10 @@ import pytest import torch +from tests.v1.attention.utils import (BatchSpec, _Backend, + create_common_attn_metadata, + create_standard_kv_cache_spec, + get_attention_backend) from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) @@ -64,13 +68,19 @@ def test_prepare_inputs(): """ device = torch.device(current_platform.device_type) - # a = 4, b = 7, c = 5 + # q1 = 4, q2 = 7, q3 = 5 # n1 = 1, n2 = 3, n3 = 2 - # Cumulative lengths: [0, 4, 11, 16] - cu_target_query_lens = torch.tensor([0, 4, 11, 16], - dtype=torch.int32, - device=device) + batch_spec = BatchSpec( + seq_lens=[4, 7, 5], + query_lens=[4, 7, 5], + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) # Rejected tokens per request: [1, 3, 2] num_rejected_tokens = torch.tensor([1, 3, 2], @@ -104,15 +114,13 @@ def test_prepare_inputs(): ], dtype=torch.int32, device=device) + proposer = _create_proposer("eagle", 1) - # n1 + n2 + n3 - a - b -c - num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum( - ).item() + updated_metadata, token_indices = proposer.prepare_inputs( + common_attn_metadata, num_rejected_tokens.cpu()) - cu_num_tokens, token_indices = EagleProposer.prepare_inputs( - cu_target_query_lens, num_rejected_tokens, num_tokens) - - assert torch.equal(cu_num_tokens, expected_cu_num_tokens) + assert torch.equal(updated_metadata.query_start_loc, + expected_cu_num_tokens) assert token_indices.shape[0] == expected_cu_num_tokens[-1].item() assert torch.equal(token_indices, expected_token_indices) @@ -209,6 +217,7 @@ def test_propose(num_speculative_tokens): seq_len_2 = 3 total_tokens = seq_len_1 + seq_len_2 vocab_size = 100 + seq_lens = [seq_len_1, seq_len_2] # Create proposer first so we can use its actual hidden_size proposer = _create_proposer("eagle", num_speculative_tokens) @@ -270,9 +279,16 @@ def create_deterministic_logits(token_ids): proposer.attn_layer_names = ["layer.0"] # Create input tensors - cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens], - dtype=torch.int32, - device=device) + batch_spec = BatchSpec( + seq_lens=seq_lens, + query_lens=seq_lens, + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) target_token_ids = torch.randint(0, vocab_size, (total_tokens, ), @@ -284,25 +300,29 @@ def create_deterministic_logits(token_ids): target_hidden_states = torch.randn(total_tokens, hidden_size, device=device) - target_slot_mapping = torch.randint(0, - 100, (total_tokens, ), - device=device) next_token_ids = torch.randint(0, vocab_size, (batch_size, ), dtype=torch.int32, device=device) - block_table = torch.randint(0, 10, (batch_size, 10), device=device) - sampling_metadata = mock.MagicMock() - # Call the method under test + attn_metadata_builder_cls, _ = get_attention_backend( + _Backend.FLASH_ATTN_VLLM_V1) + attn_metadata_builder = attn_metadata_builder_cls( + 
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + vllm_config=proposer.vllm_config, + device=device, + ) + + # Mock runner for attention metadata building + proposer.runner = mock.MagicMock() + proposer.runner.attn_metadata_builders = [attn_metadata_builder] + result = proposer.propose(target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, - target_slot_mapping=target_slot_mapping, next_token_ids=next_token_ids, - cu_num_tokens=cu_num_tokens, - block_table=block_table, + common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata) assert result.shape == (batch_size, num_speculative_tokens) diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index f1c6bdfc1c9..d63b82012a5 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -12,13 +12,12 @@ AttentionMetadata, AttentionType, is_quantized_kv_cache) from vllm.attention.backends.utils import CommonAttentionState +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable -from vllm.v1.worker.cpu_model_runner import CPUModelRunner from vllm.v1.worker.gpu_input_batch import InputBatch try: @@ -316,19 +315,21 @@ def get_seq_len_block_table_args( class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): - def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec, - block_table: BlockTable) -> None: - self.runner = runner - self.block_table = block_table + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device) -> None: + self.kv_cache_spec = kv_cache_spec + self.vllm_config = vllm_config + self.scheduler_config = vllm_config.scheduler_config + # For reorder - self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs, - dtype=np.int64) - self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs, - dtype=np.int64) + self.reorder_prompt_req_index_list = np.empty( + vllm_config.scheduler_config.max_num_seqs, dtype=np.int64) + self.reorder_decode_req_index_list = np.empty( + vllm_config.scheduler_config.max_num_seqs, dtype=np.int64) self.num_prompt_req: int = 0 self.seq_start_loc_cpu = torch.zeros( - runner.max_num_reqs + 1, + vllm_config.scheduler_config.max_num_seqs + 1, dtype=torch.int32, device="cpu", ) @@ -378,15 +379,15 @@ def reorder_batch(self, input_batch: InputBatch, return True - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> TorchSDPAMetadata: num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - runner = self.runner - block_table = self.block_table - seq_lens_np = runner.seq_lens_np[:num_reqs] + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + seq_lens_np = seq_lens_cpu.numpy() num_prompt_req = self.num_prompt_req max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item( ) if num_prompt_req > 0 else 0 @@ -394,34 +395,36 @@ def build(self, common_prefix_len: int, ) if num_prompt_req < num_reqs else 0 self.seq_start_loc_np[0] = 0 
np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1]) - num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item() - num_decode_tokens = runner.query_start_loc_np[num_reqs].item( - ) - num_prefill_tokens - slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long() - block_table_tensor = block_table.get_device_tensor() + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item()) + num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() - + num_prefill_tokens) + + slot_mapping = common_attn_metadata.slot_mapping.long() + block_table_tensor = common_attn_metadata.block_table_tensor + attn_metadata = TorchSDPAMetadata( num_prefills=num_prompt_req, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, # to ensure inference when chunked_prefill is disabled - seq_lens=runner.seq_lens_cpu[:num_reqs].tolist(), - seq_lens_tensor=runner. - seq_lens_cpu[num_prompt_req:num_reqs], # decode + seq_lens=seq_lens_cpu.tolist(), + seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs], # decode max_decode_seq_len=max_decode_seq_len, # decode block_tables=block_table_tensor[num_prompt_req:num_reqs], # decode - chunked_prefill=self.runner.scheduler_config. - chunked_prefill_enabled, + chunked_prefill=self.scheduler_config.chunked_prefill_enabled, max_query_len=max_query_len, max_kv_len=max_prefill_seq_len, - prefill_query_start_loc=runner. - query_start_loc_cpu[:num_prompt_req + 1], # prefill + prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req + + 1], # prefill kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req + 1], # prefill prefill_block_tables=block_table_tensor[: num_prompt_req], # prefill - query_start_loc=runner.query_start_loc_cpu[:num_reqs + - 1], # for logits index + query_start_loc=query_start_loc_cpu[:num_reqs + + 1], # for logits index multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=False, ) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 552c2caf2fa..4224d807c2b 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import Any, ClassVar, Optional import numpy as np import torch @@ -29,10 +29,6 @@ AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, make_local_attention_virtual_batches) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable - -if TYPE_CHECKING: - from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) @@ -162,29 +158,30 @@ class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3 - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, - block_table: BlockTable): - model_config = runner.model_config - compilation_config = runner.vllm_config.compilation_config - - self.runner = runner - self.num_heads_q = model_config.get_num_attention_heads( - runner.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - runner.parallel_config) - self.headdim = model_config.get_head_size() + def __init__(self, kv_cache_spec: 
AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.parallel_config = vllm_config.parallel_config + self.cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config + self.device = device + + self.num_heads_q = self.model_config.get_num_attention_heads( + self.parallel_config) + self.num_heads_kv = self.model_config.get_num_kv_heads( + self.parallel_config) + self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec - self.block_table = block_table self.max_num_splits = 0 # No upper bound on the number of splits. self.aot_schedule = (get_flash_attn_version() == 3) - self.use_full_cuda_graph = compilation_config.full_cuda_graph + self.use_full_cuda_graph = self.compilation_config.full_cuda_graph if self.use_full_cuda_graph: if not self.aot_schedule: raise ValueError( "AoT scheduling is required for full cuda graph.") - capture_sizes = compilation_config.cudagraph_capture_sizes + capture_sizes = self.compilation_config.cudagraph_capture_sizes if not capture_sizes: raise ValueError( "cudagraph_capture_sizes should not be None when " @@ -198,9 +195,9 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, "full cuda graph.") self.scheduler_metadata = torch.zeros( - self.runner.max_num_reqs + 1, + vllm_config.scheduler_config.max_num_seqs + 1, dtype=torch.int32, - device=self.runner.device, + device=self.device, ) # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are @@ -211,28 +208,27 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, # populated on first build() call. self.aot_sliding_window: Optional[tuple[int, int]] = None - def build( - self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata - ) -> FlashAttentionMetadata: + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> FlashAttentionMetadata: + """ + fast_build disables AOT scheduling, used when there will be few + iterations i.e. spec-decode + """ num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - - max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) + max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - non_blocking=True) - # Fill unused with -1. Needed for reshape_and_cache in full cuda graph - # mode. - block_table.slot_mapping[num_actual_tokens:].fill_(-1) + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping - slot_mapping = block_table.slot_mapping[:num_actual_tokens] + # the overhead of the aot schedule is not worth it for spec-decode + aot_schedule = self.aot_schedule and not fast_build if self.aot_sliding_window is None: self.aot_sliding_window = (-1, -1) @@ -240,19 +236,20 @@ def build( # constant for all layers to. 
We have to populate this on the first # build() call so the layers are constructed (cannot populate) # in __init__. - if self.aot_schedule: + if aot_schedule: sliding_window_configs = _get_sliding_window_configs( - self.runner.vllm_config) + self.vllm_config) if len(sliding_window_configs) == 1: sliding_window_config = sliding_window_configs.pop() if sliding_window_config is not None: self.aot_sliding_window = sliding_window_config elif len(sliding_window_configs) > 1: self.aot_schedule = False + aot_schedule = False def schedule(batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal): - if self.aot_schedule: + if aot_schedule: return get_scheduler_metadata( batch_size=batch_size, max_seqlen_q=max_query_len, @@ -271,19 +268,19 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, # for local attention local_attn_metadata = None - if self.runner.attention_chunk_size is not None: + if self.model_config.attention_chunk_size is not None: seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ virt_block_table_tensor = make_local_attention_virtual_batches( - self.runner.attention_chunk_size, - self.runner.query_start_loc_np[:num_reqs + 1], - self.runner.seq_lens_np[:num_reqs], + self.model_config.attention_chunk_size, + query_start_loc_cpu.numpy(), + seq_lens_cpu.numpy(), block_table_tensor, self.block_size, ) local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( - self.runner.device, non_blocking=True) + self.device, non_blocking=True) local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( - self.runner.device, non_blocking=True) + self.device, non_blocking=True) local_max_query_len = seqlens_q_local_np.max() local_max_seq_len = virt_k_seqlens_np.max() local_scheduler_metadata = schedule( @@ -308,14 +305,12 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, if use_cascade: cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], dtype=torch.int32, - device=self.runner.device) + device=self.device) prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32, - device=self.runner.device) - suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] - - common_prefix_len) - suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( - self.runner.device) + device=self.device) + suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to( + self.device, non_blocking=True) prefix_scheduler_metadata = schedule( batch_size=1, cu_query_lens=cu_prefix_query_lens, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index f922e6e4c9e..1eb27d57acf 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -15,22 +15,20 @@ import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import use_cascade_attention -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - PerLayerParameters, - get_kv_cache_layout, - get_per_layer_parameters, - infer_global_hyperparameters) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters, + get_kv_cache_layout, get_per_layer_parameters, + infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills, + split_decodes_and_prefills) from vllm.v1.kv_cache_interface import 
AttentionSpec -from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch - from vllm.v1.worker.gpu_model_runner import GPUModelRunner FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 @@ -226,9 +224,9 @@ def __post_init__(self): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, - block_table: BlockTable): - self.runner = runner + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + self.device = device self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode @@ -237,75 +235,22 @@ def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, # Global hyperparameters shared by all attention layers self.global_hyperparameters: Optional[PerLayerParameters] = None - self.vllm_config = runner.vllm_config + self.vllm_config = vllm_config + self.cache_config = vllm_config.cache_config self.kv_cache_spec = kv_cache_spec - self.block_table = block_table def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: - # We now want to reorder the batch so that the "decode" requests are and - # the front and the "prefill" requests are at the using the least amount - # swaps possible. (NOTE for now we loosely use "decode" to mean requests - # where attention is likely memory-bound and "prefill" to mean requests - # where attention is likely compute-bound, TODO(lucas): figure out a - # better naming here) - decodes = [] - prefills = [] - num_decode_tokens = 0 - num_prefill_tokens = 0 - - for i, req_id in enumerate(input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - # for now treat 1 scheduled token as "decode" even if its not, - # we should update this to something like < 8 in the future but - # currently the decode run only supports num_tokens = 1 - if num_tokens == 1: - decodes.append(i) - num_decode_tokens += num_tokens - else: - prefills.append(i) - num_prefill_tokens += num_tokens - - # We hope that this is fairly minimal since decodes - # should be around for a number of iterations so hopefully they are - # relatively stationary (and new request are generally appended to the - # persistent batch so already should be at the back) - # To achieve this we loop over the decodes in descending order and - # the prefills in ascending order. We swap decodes from the "back" - # i.e. past where the last decode should be in the reodorered with - # prefills from the front of the batch. 
- # `decodes` and `prefills` are already in ascending order just based on - # the above loop - num_decodes = len(decodes) - num_prefills = len(prefills) - modified_batch = False - - for i in range(1, min(num_decodes, num_prefills) + 1): - # If the decode is at the "back" of the batch, i, we can swap it - # with the prefill closest to the front of the batch - decode_idx = decodes[num_decodes - i] - if decode_idx < num_decodes: - break - - input_batch.swap_states(prefills[i - 1], decode_idx) - modified_batch = True - - # Save for next `build` call - # TODO(lucas): this is a bit of a hack, we should probably have a - # better way of doing this - self._num_decodes = num_decodes - self._num_prefills = num_prefills - self._num_decode_tokens = num_decode_tokens - self._num_prefill_tokens = num_prefill_tokens - - return modified_batch + return reorder_batch_to_split_decodes_and_prefills(input_batch, + scheduler_output, + decode_threshold=1) def _get_workspace_buffer(self): if self._workspace_buffer is None: self._workspace_buffer = torch.empty( FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, - device=self.runner.device) + device=self.device) return self._workspace_buffer def _get_prefill_wrapper(self): @@ -316,10 +261,11 @@ def _get_prefill_wrapper(self): def _get_decode_wrapper(self): if self._decode_wrapper is None: - num_qo_heads = (self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config)) - num_kv_heads = self.runner.model_config.get_num_kv_heads( - self.runner.parallel_config) + num_qo_heads = ( + self.vllm_config.model_config.get_num_attention_heads( + self.vllm_config.parallel_config)) + num_kv_heads = self.vllm_config.model_config.get_num_kv_heads( + self.vllm_config.parallel_config) use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( num_qo_heads // num_kv_heads > 4) self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( @@ -334,7 +280,8 @@ def _get_cascade_wrapper(self): 2, self._get_workspace_buffer(), get_kv_cache_layout()) return self._cascade_wrapper - def _plan(self, attn_metadata: FlashInferMetadata): + def _plan(self, num_prefills: int, num_decodes: int, + attn_metadata: FlashInferMetadata): if self.global_hyperparameters is None: self.global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(self.vllm_config, FlashInferImpl)) @@ -369,16 +316,16 @@ def _plan(self, attn_metadata: FlashInferMetadata): # Regular attention (common case). # Decodes are at the front and prefills are at the back, # according to reorder_batch() - if self._num_prefills > 0: + if num_prefills > 0: # Decodes are first so prefills start after the last decode - prefill_start = self._num_decodes + prefill_start = num_decodes attn_metadata.prefill_wrapper = self._get_prefill_wrapper() assert attn_metadata.qo_indptr[prefill_start:].shape[ - 0] == self._num_prefills + 1 + 0] == num_prefills + 1 assert attn_metadata.paged_kv_indptr[prefill_start:].shape[ - 0] == self._num_prefills + 1 + 0] == num_prefills + 1 assert attn_metadata.paged_kv_last_page_len[ - prefill_start:].shape[0] == self._num_prefills + prefill_start:].shape[0] == num_prefills # Since prefill_wrapper.run() will be called with # query[num_decode_tokens:] we need to adjust the qo_indptr # to be relative to the start of the prefill queries. 
@@ -402,17 +349,16 @@ def _plan(self, attn_metadata: FlashInferMetadata): kv_data_type=attn_metadata.kv_data_type, ) - if self._num_decodes > 0: + if num_decodes > 0: attn_metadata.decode_wrapper = self._get_decode_wrapper() if not FlashInferBackend.use_trtllm_decode_attention( - self._num_decodes, attn_metadata.max_seq_len, + num_decodes, attn_metadata.max_seq_len, attn_metadata.kv_data_type, attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim): attn_metadata.decode_wrapper.plan( - attn_metadata.paged_kv_indptr[:self._num_decodes + 1], + attn_metadata.paged_kv_indptr[:num_decodes + 1], attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_len[:self. - _num_decodes], + attn_metadata.paged_kv_last_page_len[:num_decodes], attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim, @@ -427,22 +373,20 @@ def _plan(self, attn_metadata: FlashInferMetadata): kv_data_type=attn_metadata.kv_data_type, ) - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): - num_reqs = common_attn_metadata.num_reqs + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> FlashInferMetadata: num_actual_tokens = common_attn_metadata.num_actual_tokens + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ + split_decodes_and_prefills(common_attn_metadata) - assert self._num_decodes + self._num_prefills == num_reqs - assert (self._num_decode_tokens + - self._num_prefill_tokens == num_actual_tokens) page_size = self.kv_cache_spec.block_size - device = self.runner.device + device = self.device qo_indptr = common_attn_metadata.query_start_loc - max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) + max_seq_len = common_attn_metadata.seq_lens_cpu.max() seq_lens = common_attn_metadata.seq_lens - block_table_tensor = self.block_table.get_device_tensor()[:num_reqs] - slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to( - self.runner.device, non_blocking=True).long() + block_table_tensor = common_attn_metadata.block_table_tensor block_table_bounds = (seq_lens + page_size - 1) // page_size @@ -487,7 +431,7 @@ def build(self, common_prefix_len: int, paged_kv_last_page_len = seq_lens % page_size paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len) - cache_dtype = self.runner.cache_config.cache_dtype + cache_dtype = self.cache_config.cache_dtype if cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( cache_dtype) @@ -499,17 +443,18 @@ def build(self, common_prefix_len: int, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, - num_qo_heads=self.runner.num_query_heads, + num_qo_heads=self.vllm_config.model_config.get_num_attention_heads( + self.vllm_config.parallel_config), num_kv_heads=self.kv_cache_spec.num_kv_heads, head_dim=self.kv_cache_spec.head_size, page_size=page_size, kv_data_type=kv_cache_dtype, - q_data_type=self.runner.dtype, - slot_mapping=slot_mapping, - num_decodes=self._num_decodes, - num_decode_tokens=self._num_decode_tokens, - num_prefills=self._num_prefills, - num_prefill_tokens=self._num_prefill_tokens, + q_data_type=self.vllm_config.model_config.dtype, + slot_mapping=common_attn_metadata.slot_mapping, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, 
use_cascade=use_cascade, shared_qo_indptr=shared_qo_indptr, shared_kv_page_indptr=shared_kv_page_indptr, @@ -521,12 +466,12 @@ def build(self, common_prefix_len: int, workspace_buffer=self._workspace_buffer, ) - self._plan(attn_metadata) + self._plan(num_prefills, num_decodes, attn_metadata) return attn_metadata def use_cascade_attention(self, *args, **kwargs) -> bool: - if self.kv_cache_spec.dtype != self.runner.model_config.dtype: + if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype: # TODO: The cascade wrapper currently does not support setting # kv cache dtype to something different from query dtype. return False diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index f0f54c28831..c229ec12fd1 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -3,7 +3,7 @@ """Attention layer with FlashAttention.""" from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import Any, Optional import torch from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, @@ -14,18 +14,15 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType, is_quantized_kv_cache) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable logger = init_logger(__name__) -if TYPE_CHECKING: - from vllm.v1.worker.gpu_model_runner import GPUModelRunner - create_block_mask_compiled = torch.compile(create_block_mask, fullgraph=True, mode="reduce-overhead") @@ -261,36 +258,34 @@ def __post_init__(self): class FlexAttentionMetadataBuilder( AttentionMetadataBuilder[FlexAttentionMetadata]): - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, - block_table: BlockTable): - model_config = runner.model_config - - self.runner = runner - self.num_heads_q = model_config.get_num_attention_heads( - runner.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - runner.parallel_config) - self.headdim = model_config.get_head_size() + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + self.model_config = vllm_config.model_config + self.parallel_config = vllm_config.parallel_config + self.cache_config = vllm_config.cache_config + + self.num_heads_q = self.model_config.get_num_attention_heads( + vllm_config.parallel_config) + self.num_heads_kv = self.model_config.get_num_kv_heads( + vllm_config.parallel_config) + self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.block_table = block_table + self.device = device - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> FlexAttentionMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = self.runner.seq_lens_np[:num_reqs].max() + max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) query_start_loc = 
common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - non_blocking=True) - slot_mapping = block_table.slot_mapping[:num_actual_tokens] + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping use_cascade = common_prefix_len > 0 cu_prefix_query_lens = None @@ -300,17 +295,15 @@ def build(self, common_prefix_len: int, raise NotImplementedError("Not yet my friend") block_size = self.kv_cache_spec.block_size - max_possible_seq_len = self.runner.model_config.max_model_len - total_cache_tokens = (self.runner.cache_config.num_gpu_blocks * - block_size) + max_possible_seq_len = self.model_config.max_model_len + total_cache_tokens = self.cache_config.num_gpu_blocks * block_size inverse_block_table = physical_to_logical_mapping( - block_table_tensor, self.runner.cache_config.num_gpu_blocks) + block_table_tensor, self.cache_config.num_gpu_blocks) # Get the original offset tensor - offset_tensor = torch.tensor( - self.runner.input_batch.num_computed_tokens_cpu[:num_reqs]).to( - self.runner.device, non_blocking=True) + offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to( + self.device, non_blocking=True) out = FlexAttentionMetadata( num_actual_tokens=num_actual_tokens, diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 7b4ecd7c359..dca5de46c06 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -7,15 +7,15 @@ import torch from vllm.attention.backends.abstract import AttentionBackend -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) -from vllm.v1.kv_cache_interface import MambaSpec -from vllm.v1.worker.block_table import BlockTable +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch - from vllm.v1.worker.gpu_model_runner import GPUModelRunner def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, @@ -87,80 +87,24 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( AttentionMetadataBuilder[Mamba2AttentionMetadata]): - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec, - block_table: BlockTable): - self.runner = runner + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) self.kv_cache_spec = kv_cache_spec - self.block_table = block_table - self.chunk_size = runner.vllm_config.model_config.get_mamba_chunk_size( - ) + self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() assert self.chunk_size is not None, ( "chunk_size needs to be set in the model config for Mamba2 models") def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: - # NOTE (Chen): Copied from MLACommonMetadataBuilder and - # FlashInferMetadataBuilder. Should be refactored later to avoid code - # duplication of these 3 functions. 
- # We now want to reorder the batch so that the "decode" requests are and - # the front and the "prefill" requests are at the using the least amount - # swaps possible. (NOTE for now we loosely use "decode" to mean requests - # where attention is likely memory-bound and "prefill" to mean requests - # where attention is likely compute-bound, TODO(lucas): figure out a - # better naming here) - decodes = [] - prefills = [] - num_decode_tokens = 0 - num_prefill_tokens = 0 - - for i, req_id in enumerate(input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - # for now treat 1 scheduled token as "decode" even if its not, - # we should update this to something like < 8 in the future but - # currently the decode run only supports num_tokens = 1 - if num_tokens == 1: - decodes.append(i) - num_decode_tokens += num_tokens - else: - prefills.append(i) - num_prefill_tokens += num_tokens - - # We hope that this is fairly minimal since decodes - # should be around for a number of iterations so hopefully they are - # relatively stationary (and new request are generally appended to the - # persistent batch so already should be at the back) - # To achieve this we loop over the decodes in descending order and - # the prefills in ascending order. We swap decodes from the "back" - # i.e. past where the last decode should be in the reodorered with - # prefills from the front of the batch. - # `decodes` and `prefills` are already in ascending order just based on - # the above loop - num_decodes = len(decodes) - num_prefills = len(prefills) - modified_batch = False - - for i in range(1, min(num_decodes, num_prefills) + 1): - # If the decode is at the "back" of the batch, i, we can swap it - # with the prefill closest to the front of the batch - decode_idx = decodes[num_decodes - i] - if decode_idx < num_decodes: - break - - input_batch.swap_states(prefills[i - 1], decode_idx) - modified_batch = True - - # Save for next `build` call - # TODO(lucas): this is a bit of a hack, we should probably have a - # better way of doing this - self._num_decodes = num_decodes - self._num_prefills = num_prefills - self._num_decode_tokens = num_decode_tokens - self._num_prefill_tokens = num_prefill_tokens - - return modified_batch - - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + return reorder_batch_to_split_decodes_and_prefills(input_batch, + scheduler_output, + decode_threshold=1) + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> Mamba2AttentionMetadata: num_reqs = common_attn_metadata.num_reqs query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens @@ -172,29 +116,31 @@ def build(self, common_prefix_len: int, has_initial_states = None prep_initial_states = False - state_indices_tensor = self.block_table.block_table[:num_reqs, 0] + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=1)) # Compute seq_idx, chunk_indices and chunk_offsets for prefill only - if self._num_prefills > 0: + if num_prefills > 0: #[batch,] has_initial_states_cpu = ( - self.runner.input_batch. - num_computed_tokens_cpu_tensor[num_reqs - - self._num_prefills:num_reqs] - > 0) + common_attn_metadata. 
+ num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) prep_initial_states = torch.any(has_initial_states_cpu).item() has_initial_states = has_initial_states_cpu.to( query_start_loc.device) query_start_loc_p = common_attn_metadata.query_start_loc[ - -self._num_prefills - 1:] - self._num_decode_tokens - - seq_idx = torch.repeat_interleave( - torch.arange(self._num_prefills, - dtype=torch.int32, - device=query_start_loc_p.device), - query_start_loc_p.diff(), - output_size=self._num_prefill_tokens) + -num_prefills - 1:] - num_decode_tokens + + seq_idx = torch.repeat_interleave(torch.arange( + num_prefills, + dtype=torch.int32, + device=query_start_loc_p.device), + query_start_loc_p.diff(), + output_size=num_prefill_tokens) seq_idx.unsqueeze_(0) # We compute metadata for chunked prefill once at the top level @@ -204,13 +150,13 @@ def build(self, common_prefix_len: int, chunk_indices, chunk_offsets = ( _query_start_loc_to_chunk_indices_offsets( query_start_loc_p, self.chunk_size, - self._num_prefill_tokens)) + num_prefill_tokens)) attn_metadata = Mamba2AttentionMetadata( - num_prefills=self._num_prefills, - num_prefill_tokens=self._num_prefill_tokens, - num_decodes=self._num_decodes, - num_decode_tokens=self._num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, query_start_loc=query_start_loc, seq_lens=seq_lens, has_initial_states=has_initial_states, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 173c8466f6d..93c8156b16a 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -202,18 +202,18 @@ from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import get_flash_attn_version +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, UnquantizedLinearMethod) from vllm.platforms import current_platform from vllm.utils import cdiv, round_down -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - get_per_layer_parameters, - infer_global_hyperparameters) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, + get_per_layer_parameters, infer_global_hyperparameters, + reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -235,7 +235,6 @@ if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch - from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) @@ -406,22 +405,23 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ def __init__(self, - runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, - block_table: BlockTable, + vllm_config: VllmConfig, + device: torch.device, metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata - self.runner = runner - scheduler_config = runner.scheduler_config - model_config = runner.model_config - cache_config = runner.cache_config + self.kv_cache_spec = kv_cache_spec + self.device = device + 
scheduler_config = vllm_config.scheduler_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + parallel_config = vllm_config.parallel_config self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled - self.num_heads = model_config.get_num_attention_heads( - runner.parallel_config) - self.mla_dims = get_mla_dims(model_config) + self.num_heads = self.model_config.get_num_attention_heads( + parallel_config) + self.mla_dims = get_mla_dims(self.model_config) self.aot_schedule = current_platform.is_cuda() - self.kv_cache_spec = kv_cache_spec # Dont try to access the runner on AMD if self.aot_schedule: @@ -432,7 +432,7 @@ def __init__(self, # Max sure there is enough for 8 full length request or at least # 4 pages of cache per request max( - 8 * model_config.max_model_len, 4 * + 8 * self.model_config.max_model_len, 4 * scheduler_config.max_num_seqs * cache_config.block_size), # For long-context models try not to over-allocate limiting # kv-cache space, limiting it to 64k tokens, @@ -447,13 +447,11 @@ def __init__(self, scheduler_config.max_num_seqs * cache_config.block_size self.chunked_prefill_workspace = torch.empty( (self.chunked_prefill_workspace_size, - model_config.get_head_size()), - dtype=model_config.dtype, - device=runner.device, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, ) - self.block_table = block_table - self._use_cudnn_prefill = use_cudnn_prefill() self._use_fi_prefill = use_flashinfer_prefill() self.prefill_metadata_cls = ( @@ -465,7 +463,7 @@ def __init__(self, self._workspace_buffer = torch.empty( FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, - device=runner.device) + device=device) self._fi_prefill_main: Optional[ BatchPrefillWithRaggedKVCacheWrapper] = None @@ -473,13 +471,13 @@ def __init__(self, BatchPrefillWithRaggedKVCacheWrapper] = [] self._global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(runner.vllm_config, MLACommonImpl)) + get_per_layer_parameters(vllm_config, MLACommonImpl)) if self._use_cudnn_prefill: self.cudnn_workspace = torch.empty( CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs, dtype=torch.int8, - device=runner.device, + device=device, ) def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): @@ -505,7 +503,7 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): assert num_chunks <= len(self._fi_prefill_chunks) # In MLA, the non-latent num_qo_heads == num_kv_heads - num_qo_heads = self.runner.num_query_heads + num_qo_heads = self.num_heads num_kv_heads = num_qo_heads # Sanity: Verify that num_kv_heads == 1 since it is latent space @@ -531,7 +529,7 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): sm_scale=self._global_hyperparameters.sm_scale, window_left=self._global_hyperparameters.window_left, logits_soft_cap=self._global_hyperparameters.logits_soft_cap, - q_data_type=self.runner.dtype, + q_data_type=self.model_config.dtype, kv_data_type=self.kv_cache_spec.dtype, ) @@ -552,7 +550,7 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): window_left=self._global_hyperparameters.window_left, logits_soft_cap=self._global_hyperparameters. 
logits_soft_cap, - q_data_type=self.runner.dtype, + q_data_type=self.model_config.dtype, kv_data_type=self.kv_cache_spec.dtype, ) @@ -561,63 +559,9 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: - # We now want to reorder the batch so that the "decode" requests are and - # the front and the "prefill" requests are at the using the least amount - # swaps possible. (NOTE for now we loosely use "decode" to mean requests - # where attention is likely memory-bound and "prefill" to mean requests - # where attention is likely compute-bound, TODO(lucas): figure out a - # better naming here) - decodes = [] - prefills = [] - num_decode_tokens = 0 - num_prefill_tokens = 0 - - for i, req_id in enumerate(input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - # for now treat 1 scheduled token as "decode" even if its not, - # we should update this to something like < 8 in the future but - # currently the TritonMLA._forward_decode only supports - # num_tokens = 1 - if num_tokens == 1: - decodes.append(i) - num_decode_tokens += num_tokens - else: - prefills.append(i) - num_prefill_tokens += num_tokens - - # We hope that this is fairly minimal since decodes - # should be around for a number of iterations so hopefully they are - # relatively stationary (and new request are generally appended to the - # persistent batch so already should be at the back) - # To achieve this we loop over the decodes in descending order and - # the prefills in ascending order. We swap decodes from the "back" - # i.e. past where the last decode should be in the reodorered with - # prefills from the front of the batch. - # `decodes` and `prefills` are already in ascending order just based on - # the above loop - num_decodes = len(decodes) - num_prefills = len(prefills) - modified_batch = False - - for i in range(1, min(num_decodes, num_prefills) + 1): - # If the decode is at the "back" of the batch, i, we can swap it - # with the prefill closest to the front of the batch - decode_idx = decodes[num_decodes - i] - if decode_idx < num_decodes: - break - - input_batch.swap_states(prefills[i - 1], decode_idx) - modified_batch = True - - # Save for next `build` call - # TODO(lucas): this is a bit of a hack, we should probably have a - # better way of doing this - self._num_decodes = num_decodes - self._num_prefills = num_prefills - self._num_decode_tokens = num_decode_tokens - self._num_prefill_tokens = num_prefill_tokens - - return modified_batch + return reorder_batch_to_split_decodes_and_prefills(input_batch, + scheduler_output, + decode_threshold=1) def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor): @@ -639,49 +583,50 @@ def build_for_cudagraph_capture( m.max_query_len = 1 # decode-only - # Update state usually set in reorder_batch. 
- self._num_decodes = m.num_reqs - self._num_decode_tokens = m.num_actual_tokens - self._num_prefills = 0 - self._num_prefill_tokens = 0 return self.build(0, m) - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata) -> M: + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> M: num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens + num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - assert self._num_decodes + self._num_prefills == num_reqs - # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. - device = self.runner.device - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - non_blocking=True) - block_table.slot_mapping[num_actual_tokens:].fill_(-1) - slot_mapping = block_table.slot_mapping[:num_actual_tokens] + device = self.device + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu - + query_seq_lens_cpu) + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata) + + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_tokens + prefill_metadata = None - if self._num_prefills > 0: - reqs_start = self._num_decodes # prefill_start + if num_prefills > 0: + reqs_start = num_decodes # prefill_start - context_lens_cpu = self.runner.input_batch.\ - num_computed_tokens_cpu_tensor[reqs_start:num_reqs] + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() prefill_query_start_loc = query_start_loc[ reqs_start:] - query_start_loc[reqs_start] chunked_context_metadata = None - if self.chunked_prefill_enabled and self._num_prefills > 0 \ + if self.chunked_prefill_enabled and num_prefills > 0 \ and max_context_len_cpu > 0: # NOTE: it is recommend you read the `Chunked Prefill` section # in the comment at the top of the file before trying to @@ -712,14 +657,14 @@ def build(self, common_prefix_len: int, # of `to_list`. 
chunk_starts = \ torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, self._num_prefills) \ + .unsqueeze(1).expand(-1, num_prefills) \ * max_context_chunk chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) cu_seq_lens_cpu = torch.zeros(num_chunks, - self._num_prefills + 1, + num_prefills + 1, dtype=torch.int32, pin_memory=True) torch.cumsum(chunk_seq_lens, @@ -762,28 +707,28 @@ def build(self, common_prefix_len: int, prefill_metadata.cudnn_workspace = self.cudnn_workspace decode_metadata = None - if self._num_decodes > 0: + if num_decodes > 0: decode_metadata = self._build_decode( - block_table_tensor=block_table_tensor[:self._num_decodes, ...], - seq_lens=seq_lens[:self._num_decodes], + block_table_tensor=block_table_tensor[:num_decodes, ...], + seq_lens=seq_lens[:num_decodes], ) attn_metadata = self.metadata_cls( num_reqs=common_attn_metadata.num_reqs, max_query_len=common_attn_metadata.max_query_len, - num_actual_tokens=num_actual_tokens, + num_actual_tokens=num_tokens, query_start_loc=query_start_loc, slot_mapping=slot_mapping, - head_dim=self.runner.model_config.get_head_size(), + head_dim=self.model_config.get_head_size(), # MLACommonMetadata Chunk prefill specific - num_decodes=self._num_decodes, - num_decode_tokens=self._num_decode_tokens, - num_prefills=self._num_prefills, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, prefill=prefill_metadata, decode=decode_metadata, ) - if self._use_fi_prefill and self._num_prefills > 0: + if self._use_fi_prefill and num_prefills > 0: assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata) self._build_fi_prefill_wrappers(attn_metadata.prefill) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index be26e0060db..935311aacc3 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonDecodeMetadata, @@ -18,7 +19,6 @@ MLACommonMetadata, MLACommonMetadataBuilder) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable logger = init_logger(__name__) @@ -56,12 +56,13 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): full_cudagraph_supported: ClassVar[bool] = True # Decode-only - def __init__(self, runner, kv_cache_spec: AttentionSpec, - block_table: BlockTable): - super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata) + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + super().__init__(kv_cache_spec, vllm_config, device, FlashMLAMetadata) - self.num_q_heads = self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config) + self.compilation_config = vllm_config.compilation_config + self.num_q_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) self.cg_buf_tile_scheduler_metadata = None self.cg_buf_num_splits = None @@ -75,7 +76,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, 1, # MQA for the decode path ) - if self.runner.full_cuda_graph: + if 
self.compilation_config.full_cuda_graph: # First time around (CUDAGraph capture), allocate the static buffer if self.cg_buf_tile_scheduler_metadata is None: self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index d5f9dfaea06..42a04258361 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -8,6 +8,8 @@ import vllm.envs as envs from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd +from vllm.config import VllmConfig +from vllm.utils import cdiv # yapf conflicts with isort for this docstring # yapf: disable from vllm.v1.attention.backends.mla.common import (MLACommonBackend, @@ -16,7 +18,6 @@ MLACommonMetadata, MLACommonMetadataBuilder) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable # yapf: enable @@ -65,24 +66,26 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): full_cudagraph_supported: ClassVar[bool] = True # decode only - def __init__(self, runner, kv_cache_spec: AttentionSpec, - block_table: BlockTable): - super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata) + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + super().__init__(kv_cache_spec, vllm_config, device, AiterMLAMetadata) assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ "only supports block size 1." + self.compilation_config = vllm_config.compilation_config + max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len, + self.kv_cache_spec.block_size) + max_num_reqs = vllm_config.scheduler_config.max_num_seqs + max_num_pages = max_num_reqs * max_num_pages_per_req + # Preparing persistent buffers - if self.runner.full_cuda_graph: - device = self.runner.device - max_num_reqs = self.runner.max_num_reqs + if vllm_config.compilation_config.full_cuda_graph: self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device) - self.paged_kv_indices = torch.zeros( - block_table.get_device_tensor().numel( - ), # max num pages possible - dtype=torch.int32, - device=device) + self.paged_kv_indices = torch.zeros(max_num_pages, + dtype=torch.int32, + device=device) self.paged_kv_last_page_len = torch.zeros(max_num_reqs, dtype=torch.int32, device=device) @@ -96,7 +99,8 @@ def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size block_table_bounds = (seq_lens + page_size - 1) // page_size - device = self.runner.device + device = self.device + num_reqs = seq_lens.size(0) mask = (torch.arange(block_table_tensor.size(1), dtype=block_table_tensor.dtype, @@ -113,8 +117,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, block_table_bounds.cumsum(dim=0, dtype=torch.int32) ]) - if self.runner.full_cuda_graph: - num_reqs = self._num_decodes + if self.compilation_config.full_cuda_graph: num_actual_pages = paged_kv_indices.size(0) @@ -137,7 +140,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, else: qo_indptr = torch.arange(0, - self._num_decodes + 1, + num_reqs + 1, step=1, dtype=torch.int32, device=device) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index dd86e56885e..46802bf5c2a 100644 --- 
a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import Any, Optional import torch @@ -10,18 +10,13 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType, is_quantized_kv_cache) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import ( make_local_attention_virtual_batches) from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable - -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch - from vllm.v1.worker.gpu_model_runner import GPUModelRunner if current_platform.is_rocm(): import aiter @@ -172,54 +167,49 @@ def flash_attn_varlen_func_fake( class AiterFlashAttentionMetadataBuilder: - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, - block_table: BlockTable): - model_config = runner.model_config - - self.runner = runner - self.num_heads_q = model_config.get_num_attention_heads( - runner.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - runner.parallel_config) - self.headdim = model_config.get_head_size() + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.parallel_config = vllm_config.parallel_config + self.cache_config = vllm_config.cache_config + self.device = device + + self.num_heads_q = self.model_config.get_num_attention_heads( + self.parallel_config) + self.num_heads_kv = self.model_config.get_num_kv_heads( + self.parallel_config) + self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.block_table = block_table # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. 
self.aot_sliding_window: Optional[tuple[int, int]] = None - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch(self, input_batch, scheduler_output) -> bool: return False - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> 'AiterFlashAttentionMetadata': - num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) - total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum()) + max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + total_tokens = int(common_attn_metadata.seq_lens_cpu.sum()) query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - non_blocking=True) - # Fill unused with -1. Needed for reshape_and_cache in full cuda graph - # mode. - block_table.slot_mapping[num_actual_tokens:].fill_(-1) - - slot_mapping = block_table.slot_mapping[:num_actual_tokens] + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32, - device="cuda") + device=self.device) torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, @@ -231,21 +221,21 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, # for local attention local_attn_metadata = None - if self.runner.attention_chunk_size is not None: + if self.model_config.attention_chunk_size is not None: seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ virt_block_table_tensor = make_local_attention_virtual_batches( - self.runner.attention_chunk_size, - self.runner.query_start_loc_np[:num_reqs + 1], - self.runner.seq_lens_np[:num_reqs], + self.model_config.attention_chunk_size, + query_start_loc_cpu.numpy(), + seq_lens_cpu.numpy(), block_table_tensor, self.block_size, ) local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( - self.runner.device, non_blocking=True) + self.device, non_blocking=True) local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( - self.runner.device, non_blocking=True) - local_max_query_len = int(seqlens_q_local_np.max()) - local_max_seq_len = int(virt_k_seqlens_np.max()) + self.device, non_blocking=True) + local_max_query_len = seqlens_q_local_np.max().item() + local_max_seq_len = virt_k_seqlens_np.max().item() local_scheduler_metadata = schedule( batch_size=local_query_start_loc.shape[0] - 1, cu_query_lens=local_query_start_loc, @@ -256,12 +246,11 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1, dtype=torch.int32, - device=self.runner.device) + device=self.device) local_cu_seq_lens[1:] = torch.cumsum( - torch.from_numpy(virt_k_seqlens_np).to( - device=self.runner.device, - dtype=torch.int32, - non_blocking=True), + torch.from_numpy(virt_k_seqlens_np).to(device=self.device, + dtype=torch.int32, + non_blocking=True), dim=0) diff --git 
a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 7dc90a6a97e..ee95b5af6e4 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with PagedAttention and Triton prefix prefill.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import Any, ClassVar, Optional import torch @@ -14,6 +14,7 @@ chunked_prefill_paged_decode) from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.triton_unified_attention import unified_attention +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata @@ -21,10 +22,6 @@ AttentionMetadataBuilder, CommonAttentionMetadata, make_local_attention_virtual_batches) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable - -if TYPE_CHECKING: - from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) @@ -75,12 +72,21 @@ class TritonAttentionMetadataBuilder( AttentionMetadataBuilder[TritonAttentionMetadata]): full_cudagraph_supported: ClassVar[bool] = True - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, - block_table: BlockTable): - self.runner = runner + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + self.device = device self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.block_table = block_table + + model_config = vllm_config.model_config + self.num_heads_q = model_config.get_num_attention_heads( + vllm_config.parallel_config) + self.num_heads_kv = model_config.get_num_kv_heads( + vllm_config.parallel_config) + self.headdim = model_config.get_head_size() + + self.attention_chunk_size = getattr(vllm_config.scheduler_config, + 'attention_chunk_size', None) def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata @@ -92,46 +98,36 @@ def build_for_cudagraph_capture( attn_metadata.seq_lens.fill_(1) return attn_metadata - def build( - self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata - ) -> TritonAttentionMetadata: - num_reqs = common_attn_metadata.num_reqs + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> TritonAttentionMetadata: num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) + max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - non_blocking=True) - # Fill unused with -1. Needed for reshape_and_cache in full cuda graph - # mode. 
- block_table.slot_mapping[num_actual_tokens:].fill_(-1) - - slot_mapping = block_table.slot_mapping[:num_actual_tokens] + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping # for local attention local_attn_metadata = None - if self.runner.attention_chunk_size is not None: + if self.attention_chunk_size is not None: seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ virt_block_table_tensor = make_local_attention_virtual_batches( - self.runner.attention_chunk_size, - self.runner.query_start_loc_np[:num_reqs + 1], - self.runner.seq_lens_np[:num_reqs], + self.attention_chunk_size, + common_attn_metadata.query_start_loc_cpu.numpy(), + common_attn_metadata.seq_lens_cpu.numpy(), block_table_tensor, self.block_size, ) local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( - self.runner.device, non_blocking=True) + self.device, non_blocking=True) local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( - self.runner.device, non_blocking=True) - local_max_query_len = seqlens_q_local_np.max() - local_max_seq_len = virt_k_seqlens_np.max() + self.device, non_blocking=True) + local_max_query_len = seqlens_q_local_np.max().item() + local_max_seq_len = virt_k_seqlens_np.max().item() local_attn_metadata = TritonAttentionMetadata \ .LocalAttentionMetadata( @@ -148,14 +144,13 @@ def build( if use_cascade: cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], dtype=torch.int32, - device=self.runner.device) + device=self.device) prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32, - device=self.runner.device) - suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] - + device=self.device) + suffix_kv_lens = (common_attn_metadata.seq_lens_cpu - common_prefix_len) - suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( - self.runner.device) + suffix_kv_lens = suffix_kv_lens.to(self.device) else: cu_prefix_query_lens = None prefix_kv_lens = None diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 88adc32406e..db6eaa55864 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -22,6 +22,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout) from vllm.logger import init_logger +from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) _KV_CACHE_LAYOUT_OVERRIDE = None @@ -32,14 +33,22 @@ class CommonAttentionMetadata: """ Per-batch attention metadata, shared across layers and backends. AttentionMetadataBuilder instances use it to construct per-layer metadata. + + For many of the tensors we keep both GPU and CPU versions. """ query_start_loc: torch.Tensor + query_start_loc_cpu: torch.Tensor """(batch_size + 1,), the start location of each request in query Tensor""" + seq_lens: torch.Tensor + seq_lens_cpu: torch.Tensor """(batch_size,), the length of each request including both computed tokens and newly scheduled tokens""" + num_computed_tokens_cpu: torch.Tensor + """(batch_size,), the number of computed tokens for each request""" + num_reqs: int """Number of requests""" num_actual_tokens: int @@ -47,6 +56,14 @@ class CommonAttentionMetadata: max_query_len: int """Longest query in batch""" + block_table_tensor: torch.Tensor + slot_mapping: torch.Tensor + + def __post_init__(self): + # Fill unused with -1. Needed for reshape_and_cache in full cuda graph + # mode. 
+ self.slot_mapping[self.num_actual_tokens:].fill_(-1) + M = TypeVar("M") @@ -56,11 +73,25 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): full_cudagraph_supported: ClassVar[bool] = False @abstractmethod - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata) -> M: + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + self.kv_cache_spec = kv_cache_spec + + @abstractmethod + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> M: """ Central method that builds attention metadata. Some builders (MLA) require reorder_batch to be called prior to build. + + Args: + common_prefix_len: The length of the common prefix of the batch. + common_attn_metadata: The common attention metadata. + fast_build: The meta-data will prioritize speed of building over + then speed at execution. Can be used for spec-decode where the + result of a build call may only be used for few layers/iters. """ raise NotImplementedError @@ -351,3 +382,108 @@ def make_local_attention_virtual_batches( return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \ block_table_local + + +def split_decodes_and_prefills( + common_attn_metadata: CommonAttentionMetadata, + decode_threshold: int = 1, +) -> tuple[int, int, int, int]: + """ + Assuming a reordered batch, finds the boundary between prefill and decode + requests. + + Args: + common_attn_metadata: CommonAttentionMetadata object containing the + batch metadata. + decode_threshold: The maximum query length to be considered a decode. + + Returns: + num_decodes: The number of decode requests. + num_prefills: The number of prefill requests. + num_decode_tokens: The number of tokens in the decode requests. + num_prefill_tokens: The number of tokens in the prefill requests. + """ + max_query_len = common_attn_metadata.max_query_len + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc_cpu + + if max_query_len <= decode_threshold: + return num_reqs, 0, num_tokens, 0 + + query_lens = query_start_loc[1:] - query_start_loc[:-1] + is_prefill = query_lens > decode_threshold + if not torch.any(is_prefill): + return num_reqs, 0, num_tokens, 0 + + first_prefill = is_prefill.int().argmax(dim=-1).item() + assert torch.all(query_lens[first_prefill:] > decode_threshold) + assert torch.all(query_lens[:first_prefill] <= decode_threshold) + num_decodes = first_prefill + num_prefills = num_reqs - num_decodes + num_decode_tokens = query_start_loc[first_prefill].item() + num_prefill_tokens = num_tokens - num_decode_tokens + return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) + + +def reorder_batch_to_split_decodes_and_prefills( + input_batch: "InputBatch", + scheduler_output: "SchedulerOutput", + decode_threshold: int = 1, +) -> bool: + """ + Reorders the batch to split into prefill and decode requests; places all + requests with <= decode_threshold tokens at the front of the batch. + + Returns: + True if the batch was modified, False otherwise. + """ + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the back using the least + # amount of swaps possible. 
(NOTE for now we loosely use "decode" to mean + # requests where attention is likely memory-bound and "prefill" to mean + # requests where attention is likely compute-bound, TODO(lucas): figure out + # a better naming here) + decodes = [] + prefills = [] + num_decode_tokens = 0 + num_prefill_tokens = 0 + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # for now treat 1 scheduled token as "decode" even if its not, + # we should update this to something like < 8 in the future but + # currently the TritonMLA._forward_decode only supports + # num_tokens = 1 + if num_tokens <= decode_threshold: + decodes.append(i) + num_decode_tokens += num_tokens + else: + prefills.append(i) + num_prefill_tokens += num_tokens + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. + # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + decode_idx = decodes[num_decodes - i] + if decode_idx < num_decodes: + break + + input_batch.swap_states(prefills[i - 1], decode_idx) + modified_batch = True + + return modified_batch diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 6661d984a77..967847c02ff 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import numpy as np import torch import torch.nn as nn @@ -12,11 +13,11 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel logger = init_logger(__name__) @@ -37,7 +38,6 @@ def __init__( self.method = self.speculative_config.method self.runner = runner - self.dtype = vllm_config.model_config.dtype self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size @@ -45,6 +45,7 @@ def __init__( self.speculative_config.num_speculative_tokens) self.max_num_tokens = ( vllm_config.scheduler_config.max_num_batched_tokens) + self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's # hidden size (e.g., Llama 3.3 70B). 
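The `split_decodes_and_prefills` / `reorder_batch_to_split_decodes_and_prefills` pair added to `vllm/v1/attention/backends/utils.py` above is what the Mamba2 and MLA metadata builders now delegate to: after reordering, every request with at most `decode_threshold` scheduled tokens sits at the front of the batch, and the split can be recovered from `query_start_loc_cpu` alone. A minimal standalone sketch of that boundary search (the helper name `toy_split` and the sample numbers are illustrative, not part of the patch):

import torch

def toy_split(query_start_loc_cpu: torch.Tensor,
              decode_threshold: int = 1) -> tuple[int, int, int, int]:
    # Assumes the batch has already been reordered: decodes first.
    query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    num_reqs = query_lens.numel()
    num_tokens = int(query_start_loc_cpu[-1])
    is_prefill = query_lens > decode_threshold
    if not torch.any(is_prefill):
        return num_reqs, 0, num_tokens, 0  # decode-only batch
    # The first request whose query length exceeds the threshold marks the boundary.
    first_prefill = int(is_prefill.int().argmax())
    num_decode_tokens = int(query_start_loc_cpu[first_prefill])
    return (first_prefill, num_reqs - first_prefill,
            num_decode_tokens, num_tokens - num_decode_tokens)

# Two 1-token decodes followed by two 8-token prefills.
qsl = torch.tensor([0, 1, 2, 10, 18], dtype=torch.int32)
print(toy_split(qsl))  # (2, 2, 2, 16)

The Mamba2 and MLA builders call the real helper with `decode_threshold=1`, so only single-token requests are treated as decodes.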
@@ -83,19 +84,14 @@ def propose( target_positions: torch.Tensor, # [num_tokens, hidden_size] target_hidden_states: torch.Tensor, - # [num_tokens] - target_slot_mapping: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, - # [batch_size + 1] starting with 0 - cu_num_tokens: torch.Tensor, - # [batch_size, max_num_blocks_per_req] - block_table: torch.Tensor, + common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = cu_num_tokens[1:] - 1 + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) @@ -110,50 +106,14 @@ def propose( # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] self.input_ids[last_token_indices] = next_token_ids - # FA requires seq_len to have dtype int32. - seq_lens = (target_positions[last_token_indices] + 1).int() - - if self.method in ["eagle", "eagle3"]: - # FIXME(woosuk): The below two ops cause synchronization. Optimize. - max_seq_len = seq_lens.max().item() - max_num_tokens = (cu_num_tokens[1:] - - cu_num_tokens[:-1]).max().item() - attn_metadata = FlashAttentionMetadata( - num_actual_tokens=num_tokens, - max_query_len=max_num_tokens, - query_start_loc=cu_num_tokens, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table, - slot_mapping=target_slot_mapping, - # TODO(woosuk): Support cascade attention. - use_cascade=False, - common_prefix_len=0, - cu_prefix_query_lens=None, - prefix_kv_lens=None, - suffix_kv_lens=None, - ) - elif self.method == "deepseek_mtp": - query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] - max_query_len = query_lens.max().item() - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=cu_num_tokens, - seq_lens=seq_lens, - num_reqs=batch_size, - num_actual_tokens=num_tokens, - max_query_len=max_query_len, - ) - - assert self.runner is not None + assert self.runner is not None - # FIXME: need to consider multiple kv_cache_groups - attn_metadata = self.runner.attn_metadata_builders[0].build( - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - ) - else: - raise ValueError(f"Unsupported method: {self.method}") + # FIXME: need to consider multiple kv_cache_groups + attn_metadata = self.runner.attn_metadata_builders[0].build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + fast_build=True, + ) # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. @@ -194,6 +154,11 @@ def propose( # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. + # Currently FlashAttention is the only backend that supports + # multi-token eagle spec decode. This is because the code below + # makes assumptions about attn_metadata attributes available. + assert isinstance(attn_metadata, FlashAttentionMetadata) + # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] @@ -238,8 +203,8 @@ def propose( # Compute the slot mapping. 
block_numbers = clamped_positions // self.block_size - block_ids = block_table.gather(dim=1, - index=block_numbers.view(-1, 1)) + block_ids = attn_metadata.block_table.gather( + dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) attn_metadata.slot_mapping = (block_ids * self.block_size + clamped_positions % self.block_size) @@ -275,46 +240,99 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids - @staticmethod def prepare_inputs( - # [batch_size + 1] - cu_target_query_lens: torch.Tensor, + self, + common_attn_metadata: CommonAttentionMetadata, # [batch_size] - num_rejected_tokens: torch.Tensor, - num_tokens: int, - ) -> tuple[torch.Tensor, torch.Tensor]: - # cu_target_query_lens: [0, a, a + b, a + b + c] - # num_rejected_tokens: [n1, n2, n3] - # num_tokens_per_req: [a - n1, b - n2, c - n3] - # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] - # token_indices: [0, 1, ..., a - n1 - 1, - # a, a + 1, ..., a + b - n2 - 1, - # a + b, a + b + 1, ..., a + b + c - n3 - 1] - - # [0, a, a + b, a + b + c] -> [a, b, c] - query_len_per_req = (cu_target_query_lens[1:] - - cu_target_query_lens[:-1]) - # [a, b, c] -> [a - n1, b - n2, c - n3] - num_tokens_per_req = query_len_per_req - num_rejected_tokens - - # [a - n1, b - n2, c - n3] -> - # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] - cu_num_tokens = torch.zeros_like(cu_target_query_lens) - torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) - token_indices = torch.empty( - num_tokens, + num_rejected_tokens: torch.Tensor + ) -> tuple[CommonAttentionMetadata, torch.Tensor]: + """ + This function is used to prepare the inputs for the spec decode. + It updates to the common_attn_metadata to account for the rejected + tokens (and newly sampled tokens). It also returns the token indices + of the tokens that should be fed to the speculator. + """ + # E.g. 
+ # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1, q1 + q2, q1 + q2 + q3] + # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] + # num_rejected_tokens: [n1, n2, n3] + # This function computes the intermediate values: + # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] + # And returns: + # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # common_attn_metadata.seq_lens{_cpu}: + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # token_indices: [0, 1, ..., q1 - n1 - 1, + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + + device = common_attn_metadata.query_start_loc.device + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ + - num_rejected_tokens + + # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] + new_query_len_per_req = (query_start_loc_cpu[1:] - + query_start_loc_cpu[:-1]) + # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] + new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens + new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() + + # [q1 - n1, q2 - n2, q3 - n3] -> + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + new_query_start_loc_cpu = torch.zeros( + query_start_loc_cpu.shape, dtype=torch.int32, - device=cu_target_query_lens.device, - ) - batch_size = num_rejected_tokens.shape[0] - BLOCK_SIZE = 1024 - prepare_eagle_input_kernel[(batch_size, )]( - token_indices, - cu_target_query_lens, - cu_num_tokens, - BLOCK_SIZE=BLOCK_SIZE, + pin_memory=is_pin_memory_available()) + new_query_start_loc_np = new_query_start_loc_cpu.numpy() + np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) + + total_num_tokens = new_query_start_loc_np[-1] + # Example assuming num_tokens_per_req_np = [2, 4, 3] + # this implies that `new_query_start_locs` is: + # [0, 2, 6, 9] -> + # [0, 0, 2, 2, 2, 2, 6, 6, 6] + # _r1_ ____r2____ ___r3__ + new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], + new_num_tokens_per_req_np) + # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> + # [0, 1, 0, 1, 2, 3, 0, 1, 2] + # _r1_ ____r2____ ___r3__ + token_offests = self.token_arange_np[:total_num_tokens] \ + - new_query_start_locs_expanded + + # Expand starting positions to match token pattern + # [0, q1, q1 + q2] -> + # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] + # _r1_ _____r2_______ ___________r3____________ + old_query_start_locs_expanded = np.repeat( + query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) + # Final token indices are: + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + token_indices_np = token_offests + old_query_start_locs_expanded + token_indices = torch.from_numpy(token_indices_np).to( + device, non_blocking=True) + + spec_common_attn_metadata = CommonAttentionMetadata( + query_start_loc=new_query_start_loc_cpu.to(device, + non_blocking=True), + seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), + query_start_loc_cpu=new_query_start_loc_cpu, + seq_lens_cpu=new_seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. 
+ num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], ) - return cu_num_tokens, token_indices + + return spec_common_attn_metadata, token_indices def load_model(self, target_model: nn.Module) -> None: draft_model_config = \ diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 3a86fea146f..1116179dc5b 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.sampling_params import SamplingParams -from vllm.triton_utils import tl, triton _SAMPLING_EPS = 1e-5 @@ -13,29 +12,3 @@ def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool: or sampling_params.repetition_penalty != 1.0 or sampling_params.min_p > _SAMPLING_EPS or sampling_params.logprobs is not None) - - -@triton.jit -def prepare_eagle_input_kernel( - out_ptr, - cu_query_lens_ptr, - cu_num_tokens_ptr, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - - # [start_pos, end_pos) - start_pos = tl.load(cu_num_tokens_ptr + pid) - end_pos = tl.load(cu_num_tokens_ptr + pid + 1) - num_tokens = end_pos - start_pos - - index_start = tl.load(cu_query_lens_ptr + pid) - - num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE) - for i in tl.range(num_blocks): - offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - tl.store( - out_ptr + start_pos + offset, - index_start + offset, - mask=offset < num_tokens, - ) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 8f4e8d64c61..bf38e88f0c2 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -14,12 +14,14 @@ class BlockTable: def __init__( self, + block_size: int, max_num_reqs: int, max_num_blocks_per_req: int, max_num_batched_tokens: int, pin_memory: bool, device: torch.device, ): + self.block_size = block_size self.max_num_reqs = max_num_reqs self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_batched_tokens = max_num_batched_tokens @@ -79,10 +81,31 @@ def swap_row(self, src: int, tgt: int) -> None: self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]] - def commit(self, num_reqs: int) -> None: + def compute_slot_mapping(self, req_indices: np.ndarray, + positions: np.ndarray) -> None: + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. 
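The slot-mapping arithmetic that `BlockTable.compute_slot_mapping` performs in the lines that follow can be reproduced on a toy batch with plain NumPy; the block table contents and token layout below are made up for illustration and are not taken from the patch:

import numpy as np

block_size = 2
max_num_blocks_per_req = 4
# block_table[i, j] = physical block id of request i's j-th logical block.
block_table = np.array([[7, 3, 0, 0],
                        [5, 9, 2, 0],
                        [1, 4, 0, 0]])

# One entry per scheduled token: owning request index and token position.
req_indices = np.array([0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
positions   = np.array([0, 1, 0, 1, 2, 3, 4, 0, 1, 2])

# Same arithmetic as the method below: index into the flattened block table,
# then turn (physical block, in-block offset) into a flat slot id.
block_table_indices = (req_indices * max_num_blocks_per_req
                       + positions // block_size)
block_numbers = block_table.reshape(-1)[block_table_indices]
slot_mapping = block_numbers * block_size + positions % block_size
print(slot_mapping)  # [14 15 10 11 18 19  4  2  3  8]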
+ block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions // self.block_size) + block_table_cpu = self.get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_offsets = positions % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:req_indices.shape[0]]) + + def commit_block_table(self, num_reqs: int) -> None: self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], non_blocking=True) + def commit_slot_mapping(self, num_tokens: int) -> None: + self.slot_mapping[:num_tokens].copy_( + self.slot_mapping_cpu[:num_tokens], non_blocking=True) + def clear(self) -> None: self.block_table.fill_(0) self.block_table_cpu.fill_(0) @@ -107,7 +130,8 @@ def __init__(self, max_num_reqs: int, max_model_len: int, max_num_batched_tokens: int, pin_memory: bool, device: torch.device, block_sizes: list[int]) -> None: self.block_tables = [ - BlockTable(max_num_reqs, cdiv(max_model_len, block_size), + BlockTable(block_size, max_num_reqs, cdiv(max_model_len, + block_size), max_num_batched_tokens, pin_memory, device) for block_size in block_sizes ] @@ -129,9 +153,18 @@ def swap_row(self, src: int, tgt: int) -> None: for block_table in self.block_tables: block_table.swap_row(src, tgt) - def commit(self, num_reqs: int) -> None: + def compute_slot_mapping(self, req_indices: np.ndarray, + positions: np.ndarray) -> None: + for block_table in self.block_tables: + block_table.compute_slot_mapping(req_indices, positions) + + def commit_block_table(self, num_reqs: int) -> None: + for block_table in self.block_tables: + block_table.commit_block_table(num_reqs) + + def commit_slot_mapping(self, num_tokens: int) -> None: for block_table in self.block_tables: - block_table.commit(num_reqs) + block_table.commit_slot_mapping(num_tokens) def clear(self) -> None: for block_table in self.block_tables: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index af216539c90..29f519393e4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3,7 +3,6 @@ import gc import time -import weakref from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Optional, Union @@ -42,8 +41,7 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, async_tensor_h2d, - check_use_alibi, get_dtype_size, + GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up) from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, @@ -62,7 +60,6 @@ from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer -from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -577,8 +574,9 @@ def _get_cumsum_and_arange( def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[dict[str, Any], bool, torch.Tensor, - Optional[SpecDecodeMetadata], np.ndarray]: + ) -> tuple[dict[str, + Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], + np.ndarray, Optional[CommonAttentionMetadata]]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata 
mapping, @@ -593,7 +591,7 @@ def _prepare_inputs( # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. - self.input_batch.block_table.commit(num_reqs) + self.input_batch.block_table.commit_block_table(num_reqs) # Get the number of scheduled tokens for each request. req_ids = self.input_batch.req_ids @@ -637,29 +635,10 @@ def _prepare_inputs( torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Calculate the slot mapping for each KV cache group. - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - block_size = kv_cache_group_spec.kv_cache_spec.block_size - block_table: BlockTable = self.input_batch.block_table[ - kv_cache_group_id] - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` - # here because M (max_model_len) is not necessarily divisible by - # block_size. - block_table_indices = ( - req_indices * block_table.max_num_blocks_per_req + - positions_np // block_size) - block_table_cpu = block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten( - )[block_table_indices].numpy() - block_offsets = positions_np % block_size - np.add( - block_numbers * block_size, - block_offsets, - out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) + self.input_batch.block_table.compute_slot_mapping( + req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping( + total_num_scheduled_tokens) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -696,15 +675,8 @@ def _prepare_inputs( self.query_start_loc_cpu[num_reqs].item()) query_start_loc = self.query_start_loc[:num_reqs + 1] - seq_lens = self.seq_lens[:num_reqs] - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, - seq_lens=seq_lens, - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - ) + + spec_decode_common_attn_metadata = None attn_metadata: dict[str, Any] = {} # Prepare the attention metadata for each KV cache group and make layers @@ -712,6 +684,27 @@ def _prepare_inputs( for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] + slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens] + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens=self.seq_lens[:num_reqs], + seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + num_computed_tokens_cpu=self.input_batch. + num_computed_tokens_cpu_tensor[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + block_table_tensor=blk_table_tensor, + slot_mapping=slot_mapping, + ) + + if self.speculative_config and \ + spec_decode_common_attn_metadata is None: + spec_decode_common_attn_metadata = common_attn_metadata + # Prepare for cascade attention if enabled & beneficial. 
@@ -765,7 +758,8 @@ def _prepare_inputs(
             self.set_active_loras(self.input_batch, num_scheduled_tokens)

         return (attn_metadata, attention_cuda_graphs, logits_indices,
-                spec_decode_metadata, num_scheduled_tokens)
+                spec_decode_metadata, num_scheduled_tokens,
+                spec_decode_common_attn_metadata)

     def _compute_cascade_attn_prefix_len(
         self,
@@ -1286,8 +1280,9 @@ def execute_model(

         # Prepare the decoder inputs.
         (attn_metadata, attention_cuda_graphs, logits_indices,
-         spec_decode_metadata,
-         num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
+         spec_decode_metadata, num_scheduled_tokens_np,
+         spec_decode_common_attn_metadata) = (
+             self._prepare_inputs(scheduler_output))
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.use_cuda_graph
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -1528,6 +1523,7 @@ def execute_model(
             # Speculative decoding is not enabled.
             spec_token_ids = None
         else:
+            assert spec_decode_common_attn_metadata is not None
            spec_token_ids = self.propose_draft_token_ids(
                scheduler_output,
                valid_sampled_token_ids,
@@ -1536,7 +1532,7 @@ def execute_model(
                sample_hidden_states,
                aux_hidden_states,
                spec_decode_metadata,
-                attn_metadata,
+                spec_decode_common_attn_metadata,
            )

         self.eplb_step()
@@ -1561,7 +1557,7 @@ def propose_draft_token_ids(
         sample_hidden_states: torch.Tensor,
         aux_hidden_states: Optional[torch.Tensor],
         spec_decode_metadata: Optional[SpecDecodeMetadata],
-        attn_metadata: dict[str, Any],
+        common_attn_metadata: CommonAttentionMetadata,
     ) -> list[list[int]]:
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if self.speculative_config.method == "ngram":
@@ -1608,16 +1604,6 @@ def propose_draft_token_ids(
             next_token_ids = torch.tensor(next_token_ids,
                                           dtype=torch.int32,
                                           device=self.device)
-            # At this moment, we assume all eagle layers belong to the same KV
-            # cache group, thus using the same attention metadata.
-            eagle_attn_metadata = attn_metadata[
-                self.drafter.attn_layer_names[0]]
-
-            # NOTE: deepseek_mtp uses MLA which does not have `block_table`
-            if hasattr(eagle_attn_metadata, "block_table"):
-                block_table = eagle_attn_metadata.block_table
-            else:
-                block_table = None

             if spec_decode_metadata is None:
                 # input_ids can be None for multimodal models.
@@ -1630,8 +1616,6 @@ def propose_draft_token_ids(
                                                       dim=-1)
                 else:
                     target_hidden_states = hidden_states[:num_scheduled_tokens]
-                target_slot_mapping = eagle_attn_metadata.slot_mapping
-                cu_num_tokens = eagle_attn_metadata.query_start_loc
             else:
                 # TODO(woosuk): Refactor this.
                 num_draft_tokens = spec_decode_metadata.num_draft_tokens
@@ -1639,17 +1623,12 @@ def propose_draft_token_ids(
                     n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                     for i, n in enumerate(num_draft_tokens)
                 ]
-                num_rejected_tokens_tensor = async_tensor_h2d(
-                    num_rejected_tokens,
-                    dtype=torch.int32,
-                    target_device=self.device,
-                    pin_memory=True)
-                num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
-                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                    eagle_attn_metadata.query_start_loc,
-                    num_rejected_tokens_tensor,
-                    num_tokens,
-                )
+                num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens,
+                                                       dtype=torch.int32)
+                common_attn_metadata, token_indices =\
+                    self.drafter.prepare_inputs(
+                        common_attn_metadata, num_rejected_tokens_cpu)
+
                 target_token_ids = self.input_ids[token_indices]
                 # TODO(woosuk): Support M-RoPE.
                 target_positions = self.positions[token_indices]
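The drafter now receives the batch-level `CommonAttentionMetadata` plus a CPU tensor of per-request rejected-token counts. Conceptually, `prepare_inputs` trims the rejected tail tokens of each request out of the flattened token arrays and rebuilds the cumulative query offsets. The NumPy sketch below illustrates only that idea; the real EAGLE implementation operates on the metadata object and differs in its details.

# Illustrative trimming of rejected draft tokens (not vLLM's implementation).
import numpy as np

def trim_rejected(query_start_loc, num_rejected):
    # Per-request scheduled lengths, then lengths after dropping the
    # rejected tail tokens of each request.
    query_lens = np.diff(query_start_loc)
    new_lens = query_lens - num_rejected
    new_query_start_loc = np.concatenate(([0], np.cumsum(new_lens)))
    # Indices of the kept tokens in the original flattened arrays.
    token_indices = np.concatenate([
        np.arange(start, start + keep)
        for start, keep in zip(query_start_loc[:-1], new_lens)
    ])
    return new_query_start_loc, token_indices

# Example: 3 requests with 4/3/5 scheduled tokens, rejecting 1/0/2 of them.
qsl = np.array([0, 4, 7, 12])
new_qsl, idx = trim_rejected(qsl, np.array([1, 0, 2]))
print(new_qsl)  # [0 3 6 9]
print(idx)      # [0 1 2 4 5 6 7 8 9]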
@@ -1658,17 +1637,13 @@ def propose_draft_token_ids(
                         [h[token_indices] for h in aux_hidden_states], dim=-1)
                 else:
                     target_hidden_states = hidden_states[token_indices]
-                target_slot_mapping = eagle_attn_metadata.slot_mapping[
-                    token_indices]
             draft_token_ids = self.drafter.propose(
                 target_token_ids=target_token_ids,
                 target_positions=target_positions,
                 target_hidden_states=target_hidden_states,
-                target_slot_mapping=target_slot_mapping,
                 next_token_ids=next_token_ids,
-                cu_num_tokens=cu_num_tokens,
-                block_table=block_table,
                 sampling_metadata=sampling_metadata,
+                common_attn_metadata=common_attn_metadata,
             )
             spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids
@@ -1970,24 +1945,29 @@ def _dummy_run(
         if capture_attn_cudagraph:
             attn_metadata = {}

-            query_start_loc = self.query_start_loc[:num_reqs + 1]
             # Make sure max_model_len is used at the graph capture time.
             self.seq_lens_np[:num_reqs] = self.max_model_len
             self.seq_lens_np[num_reqs:] = 0
             self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                            non_blocking=True)
-            seq_lens = self.seq_lens[:num_reqs]
-
-            common_attn_metadata = CommonAttentionMetadata(
-                query_start_loc=query_start_loc,
-                seq_lens=seq_lens,
-                num_reqs=num_reqs,
-                num_actual_tokens=num_tokens,
-                max_query_len=num_tokens,
-            )

             for kv_cache_group_id, kv_cache_group_spec in enumerate(
                     self.kv_cache_config.kv_cache_groups):
+                common_attn_metadata = CommonAttentionMetadata(
+                    query_start_loc=self.query_start_loc[:num_reqs + 1],
+                    query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
+                                                                 1],
+                    seq_lens=self.seq_lens[:num_reqs],
+                    seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                    num_computed_tokens_cpu=self.input_batch.
+                    num_computed_tokens_cpu_tensor[:num_reqs],
+                    num_reqs=num_reqs,
+                    num_actual_tokens=num_tokens,
+                    max_query_len=num_tokens,
+                    block_table_tensor=self.input_batch.block_table[
+                        kv_cache_group_id].get_device_tensor()[:num_reqs],
+                    slot_mapping=self.input_batch.
+                    block_table[kv_cache_group_id].slot_mapping[:num_reqs])

                 attn_metadata_i = self.attn_metadata_builders[
                     kv_cache_group_id].build_for_cudagraph_capture(
@@ -2339,11 +2319,10 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
                 raise ValueError(
                     f"Unknown KV cache spec type: {type(kv_cache_spec)}")

-            block_table_i = self.input_batch.block_table[i]
             attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
-                weakref.proxy(self),
                 kv_cache_spec,
-                block_table_i,
+                self.vllm_config,
+                self.device,
             )

             if (self.full_cuda_graph