diff --git a/pyproject.toml b/pyproject.toml index e8c2403af06..5926ab8c862 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ ignore_patterns = [ [tool.ruff] # Allow lines to be as long as 80. -line-length = 80 +line-length = 90 [tool.ruff.lint.per-file-ignores] "vllm/third_party/**" = ["ALL"] diff --git a/tests/spec_decode/test_tree_attention.py b/tests/spec_decode/test_tree_attention.py new file mode 100644 index 00000000000..94b4d9da126 --- /dev/null +++ b/tests/spec_decode/test_tree_attention.py @@ -0,0 +1,209 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math + +import torch +from xformers.ops.fmha.attn_bias import PagedBlockDiagonalPaddedKeysMask + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.attention.backends.tree_attn import TreeAttentionBackend + + +class NoOpLayerModule(torch.nn.Module): + _q_scale = torch.tensor(1.0, dtype=torch.float32) + _k_scale = torch.tensor(1.0, dtype=torch.float32) + _v_scale = torch.tensor(1.0, dtype=torch.float32) + + def __init__(self): + super().__init__() + + def forward(self, x): + return x + + +def forward_attention( + batch_size: int, + num_heads: int, + num_kv_heads: int, + dim_per_head: int, + block_size: int, + max_sequence_length: int, + sequence_position: int, + q_len: int, + backends: list[type[AttentionBackend]], +) -> list[torch.Tensor]: + # Assert that the number of heads is divisible by the number of KV heads. + assert num_heads % num_kv_heads == 0 + + device = "cuda" + # Initialize q, k, and v. + q = torch.randn( + (batch_size * q_len, num_heads, dim_per_head), + device=device, + dtype=torch.bfloat16, + ) + k = torch.randn( + (batch_size * q_len, num_kv_heads, dim_per_head), + device=device, + dtype=torch.bfloat16, + ) + v = torch.randn( + (batch_size * q_len, num_kv_heads, dim_per_head), + device=device, + dtype=torch.bfloat16, + ) + + # Initialize the query and KV sequence lengths. + cu_seqlens_q = q_len * torch.arange( + batch_size + 1, device=device, dtype=torch.int32) + seqlens_q = torch.diff(cu_seqlens_q) + seqlens_kv = torch.full( + (batch_size, ), + sequence_position + q_len, + device=device, + dtype=torch.int32, + ) + max_seqlen_q = q_len + max_seqlen_k = sequence_position + q_len + num_actual_tokens = cu_seqlens_q[-1] + + # Setup the block table and KV cache for paged KV. + assert max_sequence_length % block_size == 0 + max_block_count_per_batch = max_sequence_length // block_size + kv_cache = torch.randn( + ( + 2, + batch_size * max_block_count_per_batch, + block_size, + num_kv_heads, + dim_per_head, + ), + device=device, + dtype=torch.bfloat16, + ) + num_allocated_blocks_per_batch = math.ceil(max_seqlen_k / block_size) + block_table = torch.zeros( + (batch_size, max_block_count_per_batch), + device=device, + dtype=torch.int32, + ) + block_ids = torch.arange( + 0, + batch_size * num_allocated_blocks_per_batch, + device=device, + dtype=torch.int32, + ).view(-1, num_allocated_blocks_per_batch) + block_table[:, :num_allocated_blocks_per_batch] = block_ids + + # Setup the slot mapping for the input KVs. 
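+    # Each request's blocks are allocated contiguously above, so offsetting from +    # the first block id by sequence_position yields valid paged slots even when +    # sequence_position spans more than one block.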
+ slots_per_batch = [] + for i in range(batch_size): + start_offset = block_ids[i, 0] * block_size + sequence_position + slots_per_batch.append( + torch.arange( + start_offset, + start_offset + q_len, + device=device, + dtype=torch.int64, + )) + slot_mapping = torch.cat(slots_per_batch, dim=0) + + softmax_scale = q.shape[-1]**(-0.5) + layer = NoOpLayerModule() + + # Run attention for each backend and collect the outputs. + outputs = [] + for backend_cls in backends: + # Set common metadata. + attn_metadata_dict = { + "num_actual_tokens": num_actual_tokens, + "max_query_len": max_seqlen_q, + "query_start_loc": cu_seqlens_q, + "max_seq_len": max_seqlen_k, + "seq_lens": seqlens_kv, + "block_table": block_table, + "slot_mapping": slot_mapping, + } + + # Set backend-specific metadata. + if backend_cls == FlashAttentionBackend: + attn_metadata_dict["use_cascade"] = False + attn_metadata_dict["common_prefix_len"] = 0 + attn_metadata_dict["cu_prefix_query_lens"] = None + attn_metadata_dict["prefix_kv_lens"] = None + attn_metadata_dict["suffix_kv_lens"] = None + elif backend_cls == TreeAttentionBackend: + # Construct the prefix bias. + prefix_kv_seqlens = seqlens_kv - seqlens_q + prefix_attn_bias = PagedBlockDiagonalPaddedKeysMask.from_seqlens( + q_seqlen=seqlens_q.tolist(), + kv_seqlen=prefix_kv_seqlens.tolist(), + page_size=block_size, + block_tables=block_table, + device=device, + ) + attn_metadata_dict["prefix_attn_bias"] = prefix_attn_bias + # Create a chain attn bias. + chain_attn_bias = torch.triu( + torch.full((q_len, q_len), + float("-inf"), + device=device, + dtype=torch.bfloat16), + diagonal=1, + ) + attn_metadata_dict["spec_attn_bias"] = chain_attn_bias + attn_metadata_dict["prefill_attn_metadata"] = None + + # Initialize the backend implementation. + instance = backend_cls.get_impl_cls()( + num_heads=num_heads, + head_size=dim_per_head, + scale=softmax_scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + ) + + # Run forward pass and store output. + output = torch.empty_like(q) + outputs.append( + instance.forward( + layer=layer, + query=q, + key=k, + value=v, + kv_cache=kv_cache.clone(), + attn_metadata=backend_cls.get_metadata_cls()( + **attn_metadata_dict), + output=output, + )) + return outputs + + +def test_tree_attn_correctness() -> None: + torch.cuda.manual_seed_all(0) + + for batch_size in [1, 2, 16, 32, 64]: + for num_heads in [2, 4]: + for sequence_position in [16, 1024, 2048]: + for q_len in [1, 3, 7]: + flash_attn_output, tree_attn_output = forward_attention( + batch_size=batch_size, + num_heads=num_heads, + num_kv_heads=2, + dim_per_head=128, + block_size=128, + max_sequence_length=8192, + sequence_position=sequence_position, + q_len=q_len, + backends=[FlashAttentionBackend, TreeAttentionBackend], + ) + assert torch.allclose( + flash_attn_output, tree_attn_output, atol=7.81e-3 + ), (f"outputs are not close for batch_size: {batch_size}, " + f"num_heads: {num_heads}, " + f"sequence_position: {sequence_position}, " + f"q_len: {q_len}.") diff --git a/vllm/config.py b/vllm/config.py index 84aa14b7c86..6d251e652c5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2724,6 +2724,13 @@ def __post_init__(self): f"num_speculative_tokens:{self.num_speculative_tokens}" f" must be divisible by {n_predict=}") + if self.speculative_token_tree is None: + # Generate chain of tokens. 
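+            # A chain is a tree in which every level has a single child, so +            # level i is encoded as a tuple of (i + 1) zeros (the deepest path +            # has num_speculative_tokens zeros).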
+ self.speculative_token_tree = str([[ + (i + 1) * (0, ) + for i in range(self.num_speculative_tokens) + ]]) + self.draft_tensor_parallel_size = \ SpeculativeConfig._verify_and_get_draft_tp( self.target_parallel_config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6c908f88b9a..bc7c90e9911 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1427,7 +1427,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # No XFormers so far. V1_BACKENDS = [ "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", @@ -1442,6 +1441,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "ROCM_AITER_MLA", "TORCH_SDPA_VLLM_V1", "FLEX_ATTENTION", + "TREE_ATTN", ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 879d094f657..2dbefb3c022 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -248,6 +248,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info_once("Using Flash Attention backend on V1 engine.") return ("vllm.v1.attention.backends." "flash_attn.FlashAttentionBackend") + elif selected_backend == _Backend.TREE_ATTN: + logger.info_once("Using Tree Attention backend on V1 engine.") + return ("vllm.v1.attention.backends." + "tree_attn.TreeAttentionBackend") # Default backends for V1 engine # Prefer FlashInfer for Blackwell GPUs if installed diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 0f08bf98633..023dc6c78da 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -61,6 +61,7 @@ class _Backend(enum.Enum): DUAL_CHUNK_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto() FLEX_ATTENTION = enum.auto() + TREE_ATTN = enum.auto() class PlatformEnum(enum.Enum): diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py new file mode 100644 index 00000000000..8f3b1834461 --- /dev/null +++ b/vllm/v1/attention/backends/tree_attn.py @@ -0,0 +1,525 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with TreeAttention.""" + +import ast +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch +from xformers.ops.fmha import triton_splitk +from xformers.ops.fmha.attn_bias import (AttentionBias, + PagedBlockDiagonalPaddedKeysMask) +from xformers.ops.tree_attention import (_get_depth_counts, + _prepare_tree_attn_bias, + tree_attention) + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.ops.triton_unified_attention import unified_attention +from vllm.logger import init_logger +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +from vllm import _custom_ops as ops + +logger = init_logger(__name__) + + +class TreeAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def 
get_name() -> str: + return "TREE_ATTN_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["TreeAttentionImpl"]: + return TreeAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return TreeAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]: + return TreeAttentionMetadataBuilder + + +@dataclass +class TreeAttentionMetadata: + num_actual_tokens: int # Number of tokens excluding padding. + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + num_prefill_tokens: int = 0 + num_decode_tokens: int = 0 + num_prefills: int = 0 + num_decodes: int = 0 + + prefix_attn_bias: Optional[AttentionBias] = None + spec_attn_bias: Optional[torch.Tensor] = None + + # Cached Prefill/decode metadata. + _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None + _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + # Recover cached prefill-phase attention + # metadata structure + return self._cached_prefill_metadata + + q_start_loc = self.query_start_loc[self.num_decodes:] + q_seqlens = torch.diff(q_start_loc) + kv_seqlens = self.seq_lens[self.num_decodes:] + # Construct & cache prefill-phase attention metadata structure + self._cached_prefill_metadata = TreeAttentionMetadata( + num_actual_tokens=self.num_prefill_tokens, + max_query_len=int(q_seqlens.max().item()), + query_start_loc=q_start_loc - q_start_loc[0], + max_seq_len=int(kv_seqlens.max().item()), + seq_lens=kv_seqlens, + block_table=self.block_table[self.num_decodes:], + slot_mapping=self.slot_mapping[self.num_decode_tokens:], + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["TreeAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + # Recover cached decode-phase attention + # metadata structure + return self._cached_decode_metadata + + q_start_loc = self.query_start_loc[:self.num_decodes + 1] + q_seqlens = torch.diff(q_start_loc) + kv_seqlens = self.seq_lens[:self.num_decodes] + # Construct & cache decode-phase attention metadata structure + self._cached_decode_metadata = TreeAttentionMetadata( + num_actual_tokens=self.num_decode_tokens, + max_query_len=int(q_seqlens.max().item()), + query_start_loc=q_start_loc, + max_seq_len=int(kv_seqlens.max().item()), + seq_lens=kv_seqlens, + block_table=self.block_table[:self.num_decodes], + slot_mapping=self.slot_mapping[:self.num_decode_tokens], + prefix_attn_bias=self.prefix_attn_bias, + spec_attn_bias=self.spec_attn_bias, + ) + return self._cached_decode_metadata + + @staticmethod + def from_eagle_attn_metadata( + flash_attn_metadata: "FlashAttentionMetadata", + ) -> "TreeAttentionMetadata": + num_prefills = flash_attn_metadata.query_start_loc.shape[0] - 1 + return TreeAttentionMetadata( + num_actual_tokens=flash_attn_metadata.num_actual_tokens, + num_prefill_tokens=flash_attn_metadata.num_actual_tokens, + 
num_prefills=num_prefills, + max_query_len=flash_attn_metadata.max_query_len, + query_start_loc=flash_attn_metadata.query_start_loc, + max_seq_len=flash_attn_metadata.max_seq_len, + seq_lens=flash_attn_metadata.seq_lens, + block_table=flash_attn_metadata.block_table, + slot_mapping=flash_attn_metadata.slot_mapping, + ) + + +class TreeAttentionMetadataBuilder( + AttentionMetadataBuilder[TreeAttentionMetadata]): + + def __init__( + self, + runner: "GPUModelRunner", + kv_cache_spec: AttentionSpec, + block_table: BlockTable, + ): + self.runner = runner + self.kv_cache_spec = kv_cache_spec + self.block_table = block_table + self.block_size = kv_cache_spec.block_size + + spec_config = runner.vllm_config.speculative_config + spec_token_tree = spec_config.speculative_token_tree + tree_choices: list[tuple[int, + ...]] = (ast.literal_eval(spec_token_tree) if + spec_token_tree is not None else []) + # Construct the tree attention bias. + depth_counts = _get_depth_counts(tree_choices) + self.tree_attn_bias = _prepare_tree_attn_bias( + tree_choices, + depth_counts, + dtype=self.kv_cache_spec.dtype, + device=block_table.device, + ) + self.suffix_attn_bias = self.tree_attn_bias + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are and + # the front and the "prefill" requests are at the using the least amount + # swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + num_decode_tokens = 0 + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # for now treat 1 scheduled token as "decode" even if its not, + # we should update this to something like < 8 in the future but + # currently the decode run only supports num_tokens = 1 + # For now, treat any decode step with exactly + if num_tokens == self.suffix_attn_bias.shape[0]: + decodes.append(i) + num_decode_tokens += num_tokens + else: + prefills.append(i) + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. 
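+        # The end state is that rows [0, num_decodes) of the batch hold decode +        # requests and the remaining rows hold prefills, which is the layout that +        # the prefill_metadata/decode_metadata slicing above relies on.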
+ # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + decode_idx = decodes[num_decodes - i] + if decode_idx < num_decodes: + break + + input_batch.swap_states(prefills[i - 1], decode_idx) + modified_batch = True + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + self._num_decodes = num_decodes + self._num_decode_tokens = num_decode_tokens + + return modified_batch + + def build( + self, common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata + ) -> TreeAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + num_decodes = self._num_decodes + num_prefills = num_reqs - num_decodes + num_actual_tokens = common_attn_metadata.num_actual_tokens + num_decode_tokens = self._num_decode_tokens + num_prefill_tokens = num_actual_tokens - num_decode_tokens + q_start_loc = common_attn_metadata.query_start_loc + q_seqlens = torch.diff(q_start_loc) + max_query_len = common_attn_metadata.max_query_len + kv_seqlens = common_attn_metadata.seq_lens + max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) + block_table = self.block_table + slot_mapping = block_table.slot_mapping + + # Get the block table and slot mapping for paged KV. + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + slot_mapping[:num_actual_tokens].copy_( + block_table.slot_mapping_cpu[:num_actual_tokens], + non_blocking=True, + ) + # Fill unused with -1. Needed for reshape_and_cache in full cuda graph + # mode. + slot_mapping[num_actual_tokens:].fill_(-1) + + prefix_attn_bias = None + if num_decodes > 0: + # Construct the prefix bias. + decode_q_seqlens = q_seqlens[:num_decodes] + decode_kv_seqlens = kv_seqlens[:num_decodes] + prefix_kv_seqlens = decode_kv_seqlens - decode_q_seqlens + prefix_attn_bias = PagedBlockDiagonalPaddedKeysMask.from_seqlens( + q_seqlen=decode_q_seqlens.tolist(), + kv_seqlen=prefix_kv_seqlens.tolist(), + page_size=self.block_size, + block_tables=block_table_tensor[:num_decodes], + device=block_table.device, + ) + + return TreeAttentionMetadata( + num_actual_tokens=num_actual_tokens, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_decodes=num_decodes, + max_query_len=max_query_len, + query_start_loc=q_start_loc, + max_seq_len=max_seq_len, + seq_lens=kv_seqlens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + prefix_attn_bias=prefix_attn_bias, + spec_attn_bias=self.suffix_attn_bias, + ) + + def build_for_drafting( + self, + common_attn_metadata: CommonAttentionMetadata, + tree_level_offset: int, + ) -> TreeAttentionMetadata: + orig_num_decodes = self._num_decodes + orig_num_decode_tokens = self._num_decode_tokens + # While drafting, all requests are treated as decodes. + self._num_decodes = common_attn_metadata.num_reqs + self._num_decode_tokens = common_attn_metadata.num_actual_tokens + + # Slice the suffix attention bias so that + query_len = common_attn_metadata.max_query_len + start, end = tree_level_offset, tree_level_offset + query_len + self.suffix_attn_bias = self.tree_attn_bias[start:end, start:end] + + # Build attention bias. 
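+        # build() treats every request as a decode here and picks up the sliced +        # bias through self.suffix_attn_bias, so the resulting metadata covers +        # only the draft tokens in the current drafting window.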
+ attn_metadata = self.build(0, common_attn_metadata) + + # Reset properties to original values. + self._num_decodes = orig_num_decodes + self._num_decode_tokens = orig_num_decode_tokens + self.suffix_attn_bias = self.tree_attn_bias + return attn_metadata + + +class TreeAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + + support_head_sizes = TreeAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by TreeAttention. " + f"Supported head sizes are: {support_head_sizes}. " + "Set VLLM_USE_V1=0 to use another attention backend.") + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: TreeAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with TreeAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for TreeAttentionImpl") + + if attn_metadata is None: + # Profiling run. + return output + + # Cache the input KVs. + key_cache, value_cache = kv_cache.unbind(0) + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. 
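+            # Each slot encodes block_id * block_size + in-block offset, so +            # reshape_and_cache_flash can scatter the new keys and values +            # directly into the paged KV cache.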
+ ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + num_actual_tokens = attn_metadata.num_actual_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + num_decodes = attn_metadata.num_decodes + if prefill_meta := attn_metadata.prefill_metadata: + descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, + key.shape[1]) + unified_attention( + q=query[num_decode_tokens:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[num_decode_tokens:num_actual_tokens], + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + seqused_k=prefill_meta.seq_lens, + max_seqlen_k=prefill_meta.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=prefill_meta.block_table, + softcap=self.logits_soft_cap, + q_descale=None, # Not supported + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + + if decode_meta := attn_metadata.decode_metadata: + # Get only the speculatively decoded q, k, and vs. + spec_q = query[:num_decode_tokens] + spec_k = key[:num_decode_tokens] + spec_v = value[:num_decode_tokens] + + # Reshape q, k, and vs to [B, M, H, D]. + spec_q = spec_q.view(num_decodes, -1, self.num_heads, + self.head_size) + spec_k = spec_k.view(num_decodes, -1, self.num_kv_heads, + self.head_size) + spec_v = spec_v.view(num_decodes, -1, self.num_kv_heads, + self.head_size) + # Reshape the KV cache to [Bkv, Mk, H, D] + cache_k = key_cache.view(1, -1, self.num_kv_heads, self.head_size) + cache_v = value_cache.view(1, -1, self.num_kv_heads, + self.head_size) + + if self.num_kv_heads != self.num_heads: + # GQA/MQA. Reshape q, k, and v to [B, M, G, H, K]. + spec_q = spec_q.view( + spec_q.shape[0], + spec_q.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + spec_q.shape[-1], + ) + spec_k = spec_k[:, :, :, None, :].expand( + spec_k.shape[0], + spec_k.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + spec_k.shape[-1], + ) + spec_v = spec_v[:, :, :, None, :].expand( + spec_v.shape[0], + spec_v.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + spec_v.shape[-1], + ) + # Reshape the KV cache to [Bkv, Mk, G, H, K] + cache_k = cache_k[:, :, :, None, :].expand( + cache_k.shape[0], + cache_k.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + cache_k.shape[-1], + ) + cache_v = cache_v[:, :, :, None, :].expand( + cache_v.shape[0], + cache_v.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + cache_v.shape[-1], + ) + + # Perform tree attention on the speculatively decoded tokens. 
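+            # tree_attention attends to the prefix KVs read from the paged cache +            # (masked by prefix_attn_bias) and to the in-flight draft tokens +            # themselves (masked by the tree bias), then merges the two partial +            # results.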
+ output[:num_decode_tokens] = tree_attention( + q=spec_q, + spec_k=spec_k, + spec_v=spec_v, + cache_k=cache_k, + cache_v=cache_v, + prefix_op=triton_splitk.FwOp, + suffix_op=triton_splitk.FwOp, + prefix_attn_bias=decode_meta.prefix_attn_bias, + spec_attn_bias=decode_meta.spec_attn_bias, + ).view(-1, self.num_heads, self.head_size) + return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 8083f200260..7bbf0edb80f 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -77,6 +77,17 @@ def build_for_cudagraph_capture( return self.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata) + def build_for_drafting( + self, + common_attn_metadata: CommonAttentionMetadata, + tree_level_offset: int, + ) -> M: + """ + Build attention metadata for draft model. Uses build by default. + """ + return self.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + def use_cascade_attention( self, common_prefix_len: int, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 156f5764e8d..89ad1d6d01c 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast + import torch import torch.nn as nn @@ -14,6 +16,8 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata, FlashAttentionMetadata) +from vllm.v1.attention.backends.tree_attn import (TreeAttentionBackend, + TreeAttentionMetadata) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel @@ -41,10 +45,8 @@ def __init__( self.dtype = vllm_config.model_config.dtype self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size - self.num_speculative_tokens = ( - self.speculative_config.num_speculative_tokens) - self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + self.num_speculative_tokens = self.speculative_config.num_speculative_tokens + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's # hidden size (e.g., Llama 3.3 70B). @@ -68,12 +70,52 @@ def __init__( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device) - # We need +1 here because the arange is used to set query_start_loc, - # which has one more element than batch_size. - self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + - 1, - device=device, - dtype=torch.int32) + + max_batch_size = vllm_config.scheduler_config.max_num_seqs + self.arange = torch.arange( + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. + max_batch_size + 1, + device=device, + dtype=torch.int32, + ) + + # Parse the speculative token tree. + spec_token_tree = self.speculative_config.speculative_token_tree + self.tree_choices: list[tuple[int, + ...]] = ast.literal_eval(spec_token_tree) + tree_depth = len(self.tree_choices[-1]) + # Precompute per-level properties of the tree. 
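+        # For example, tree_choices = [(0,), (1,), (0, 0), (1, 0)] has two drafts +        # at each of its two levels, giving cu_drafts_per_level = [2, 4], +        # child_drafts_per_level = [2, 1], and first_branching_level = 0.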
+ num_drafts_per_level = [0] * tree_depth + for node in self.tree_choices: + num_drafts_per_level[len(node) - 1] += 1 + self.cu_drafts_per_level = [num_drafts_per_level[0]] + self.child_drafts_per_level = [num_drafts_per_level[0]] + for level in range(1, tree_depth): + self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] + + num_drafts_per_level[level]) + self.child_drafts_per_level.append(num_drafts_per_level[level] // + num_drafts_per_level[level - 1]) + # Find the first level where the tree branches off into one or more + # children. + self.first_branching_level = None + for level in range(tree_depth): + if self.cu_drafts_per_level[level] > level + 1: + self.first_branching_level = level + break + # Precompute draft position offsets in flattened tree. + self.tree_draft_pos_offsets = torch.arange( + 1, + len(self.tree_choices) + 1, + device=device, + dtype=torch.int32, + ).repeat(max_batch_size, 1) + + def get_padded_num_input_tokens(self, num_tokens: int): + return (self.vllm_config.pad_for_cudagraph(num_tokens) + if self.use_cuda_graph + and num_tokens <= self.cudagraph_batch_sizes[-1] else + num_tokens) def propose( self, @@ -133,6 +175,11 @@ def propose( prefix_kv_lens=None, suffix_kv_lens=None, ) + + if (self.runner is not None + and self.runner.attn_backends[0] == TreeAttentionBackend): + attn_metadata = TreeAttentionMetadata.from_eagle_attn_metadata( + attn_metadata) elif self.method == "deepseek_mtp": query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] max_query_len = query_lens.max().item() @@ -160,11 +207,7 @@ def propose( per_layer_attn_metadata = {} for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) - else: - num_input_tokens = num_tokens + num_input_tokens = self.get_padded_num_input_tokens(num_tokens) # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states @@ -183,10 +226,10 @@ def propose( last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) - draft_token_ids = logits.argmax(dim=-1) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: + draft_token_ids = logits.argmax(dim=-1) # [batch_size, 1] return draft_token_ids.view(-1, 1) @@ -194,41 +237,98 @@ def propose( # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. - # Generate the remaining draft tokens. + # Sample a draft token for each tree child at the root level. + draft_token_ids = torch.topk(logits, + self.child_drafts_per_level[0], + dim=-1).indices + + # Prepare to generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] + draft_last_hidden_states = hidden_states[last_token_indices].view( + batch_size, 1, -1) - positions = target_positions[last_token_indices] - hidden_states = hidden_states[last_token_indices] - if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: - input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) - else: - input_batch_size = batch_size + # Setup attention metadata for drafting. 
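+        # The first drafting step runs exactly one query token per request, so +        # the query layout collapses to one token per sequence.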
attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] - for _ in range(self.num_speculative_tokens - 1): - # Update the inputs. - # cast to int32 is crucial when eagle model is compiled. - # tensor.argmax() returns int64 by default. - input_ids = draft_token_ids_list[-1].int() - positions += 1 - - # NOTE(woosuk): We should handle the case where the draft model - # generates tokens beyond the max model length. Since it is complex - # to remove such requests from the batch, we keep them in the batch - # but adjust the position ids and slot mappings to avoid the - # out-of-range access during the model execution. The draft tokens - # generated with this adjustment should be ignored. - exceeds_max_model_len = positions >= self.max_model_len + + # Initialize empty tensors for concatenation with the level outputs. + input_ids = torch.empty(0, + device=self.input_ids.device, + dtype=self.input_ids.dtype) + positions = torch.empty(0, + device=self.positions.device, + dtype=self.positions.dtype) + hidden_states = torch.empty(0, + device=self.hidden_states.device, + dtype=self.hidden_states.dtype) + # Precompute the draft token positions. + base_positions = target_positions[last_token_indices] + flattened_draft_positions = ( + base_positions.view(batch_size, -1) + + self.tree_draft_pos_offsets[:batch_size, :]) + total_num_drafts = 0 + tree_depth = len(self.cu_drafts_per_level) + for level in range(tree_depth - 1): + num_level_drafts = self.cu_drafts_per_level[ + level] - total_num_drafts + total_num_drafts = self.cu_drafts_per_level[level] + + # Get draft positions for RoPE. + draft_positions = base_positions + (level + 1) + exceeds_max_model_len = (base_positions + + total_num_drafts) >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) + clamped_draft_positions = torch.where( + exceeds_max_model_len, + 0, + draft_positions, + ) + draft_positions = clamped_draft_positions.repeat_interleave( + num_level_drafts).reshape(batch_size, -1) + + # Broadcast draft hidden states for each child. + draft_hidden_states = draft_last_hidden_states.repeat_interleave( + self.child_drafts_per_level[level], dim=1) + + if (self.first_branching_level is not None + and level >= self.first_branching_level): + # Draft branching has occurred. Tree attention must be used to + # predict subsequent draft tokens. + query_len = total_num_drafts - self.first_branching_level + input_ids = torch.cat([input_ids, draft_token_ids], dim=1) + positions = torch.cat([positions, draft_positions], dim=1) + hidden_states = torch.cat([hidden_states, draft_hidden_states], + dim=1) + + # Build new attention metadata for the next level of drafts. + # This is necessary to support tree attention. + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_len * self.arange[:batch_size + 1], + seq_lens=attn_metadata.seq_lens + num_level_drafts, + num_reqs=batch_size, + num_actual_tokens=batch_size * query_len, + max_query_len=query_len, + ) + attn_metadata = self.runner.attn_metadata_builders[ + 0].build_for_drafting( + common_attn_metadata=common_attn_metadata, + tree_level_offset=self.first_branching_level, + ) + + # Apply new attention metadata to all layers. 
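+                # The drafter's layers share a single metadata object, so +                # rebuilding it once per tree level is sufficient.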
+ per_layer_attn_metadata = {} + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + else: + # Still a chain of drafts. Continue performing standard attention, + # which is more efficient than tree attention. + query_len = 1 + input_ids = draft_token_ids + positions = draft_positions + hidden_states = draft_hidden_states - # Increment the sequence lengths. - attn_metadata.max_seq_len += 1 - attn_metadata.seq_lens += 1 # Consider max model length. attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, self.max_model_len) @@ -237,43 +337,53 @@ def propose( attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) # Compute the slot mapping. - block_numbers = clamped_positions // self.block_size - block_ids = block_table.gather(dim=1, - index=block_numbers.view(-1, 1)) - block_ids = block_ids.view(-1) - attn_metadata.slot_mapping = (block_ids * self.block_size + - clamped_positions % self.block_size) + query_positions = flattened_draft_positions[:, level:level + + query_len] + block_numbers = query_positions // self.block_size + block_ids = block_table.gather(dim=1, index=block_numbers) + slot_mapping = (block_ids * self.block_size + + query_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. - attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, - PADDING_SLOT_ID) + slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID + attn_metadata.slot_mapping = slot_mapping.view(-1) - # copy inputs to buffer for cudagraph - self.input_ids[:batch_size] = input_ids - self.positions[:batch_size] = clamped_positions - self.hidden_states[:batch_size] = hidden_states + # Copy inputs to buffer for cudagraph. + num_tokens = attn_metadata.num_actual_tokens + self.input_ids[:num_tokens] = input_ids.view(-1) + self.positions[:num_tokens] = positions.view(-1) + self.hidden_states[:num_tokens] = hidden_states.view( + num_tokens, -1) + num_input_tokens = self.get_padded_num_input_tokens(num_tokens) # Run the model. with set_forward_context(per_layer_attn_metadata, self.vllm_config, - num_tokens=input_batch_size): - last_hidden_states, hidden_states = self.model( - self.input_ids[:input_batch_size], - self.positions[:input_batch_size], - self.hidden_states[:input_batch_size], + num_tokens=num_input_tokens): + last_hidden_states, _ = self.model( + self.input_ids[:num_input_tokens], + self.positions[:num_input_tokens], + self.hidden_states[:num_input_tokens], ) - hidden_states = hidden_states[:batch_size] - logits = self.model.compute_logits(last_hidden_states[:batch_size], - None) - # TODO(wenlong): get more than one token for tree attention - draft_token_ids = logits.argmax(dim=-1) + # Get the last hidden states for predicting the drafts. + draft_last_hidden_states = last_hidden_states[:num_tokens].view( + batch_size, query_len, -1)[:, -num_level_drafts:] + + # Sample a draft token for each child at the current tree level. 
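+            # Each of the num_level_drafts parents contributes +            # child_drafts_per_level[level + 1] children via top-k over its +            # logits; the result is flattened back to shape +            # [batch_size, drafts at the next level].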
+ logits = self.model.compute_logits( + draft_last_hidden_states.reshape(batch_size * num_level_drafts, + -1), None) + draft_token_ids = torch.topk(logits, + self.child_drafts_per_level[level + + 1], + dim=-1).indices.view(batch_size, -1) draft_token_ids_list.append(draft_token_ids) # [batch_size, num_speculative_tokens] - draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - return draft_token_ids + output = torch.cat(draft_token_ids_list, dim=1) + return torch.cat(draft_token_ids_list, dim=1) @staticmethod def prepare_inputs( @@ -292,8 +402,7 @@ def prepare_inputs( # a + b, a + b + 1, ..., a + b + c - n3 - 1] # [0, a, a + b, a + b + c] -> [a, b, c] - query_len_per_req = (cu_target_query_lens[1:] - - cu_target_query_lens[:-1]) + query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1] # [a, b, c] -> [a - n1, b - n2, c - n3] num_tokens_per_req = query_len_per_req - num_rejected_tokens @@ -317,12 +426,12 @@ def prepare_inputs( return cu_num_tokens, token_indices def load_model(self, target_model: nn.Module) -> None: - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config + draft_model_config = self.vllm_config.speculative_config.draft_model_config target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, Attention).keys()) from vllm.compilation.backends import set_model_tag + with set_model_tag("eagle_head"): self.model = get_model(vllm_config=self.vllm_config, model_config=draft_model_config) @@ -335,33 +444,29 @@ def load_model(self, target_model: nn.Module) -> None: if supports_multimodal(target_model): # handle multimodality - self.model.config.image_token_index = ( - target_model.config.image_token_index) + self.model.config.image_token_index = target_model.config.image_token_index target_language_model = target_model.get_language_model() else: target_language_model = target_model # share embed_tokens with the target model if needed - if get_pp_group().world_size == 1 \ - and self.model.model.embed_tokens.weight.shape \ - == target_language_model.model.embed_tokens.weight.shape: + if (get_pp_group().world_size == 1 + and self.model.model.embed_tokens.weight.shape + == target_language_model.model.embed_tokens.weight.shape): logger.info( - "Assuming the EAGLE head shares the same vocab embedding" \ - " with the target model." - ) + "Assuming the EAGLE head shares the same vocab embedding" + " with the target model.") del self.model.model.embed_tokens - self.model.model.embed_tokens = ( - target_language_model.model.embed_tokens) + self.model.model.embed_tokens = target_language_model.model.embed_tokens else: logger.info( - "The EAGLE head's vocab embedding will be loaded separately" \ - " from the target model." 
- ) + "The EAGLE head's vocab embedding will be loaded separately" + " from the target model.") # share lm_head with the target model if needed # some model definition do not define lm_head explicitly # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM - if self.vllm_config.speculative_config.method != "eagle3" and \ - hasattr(target_language_model, "lm_head"): + if self.vllm_config.speculative_config.method != "eagle3" and hasattr( + target_language_model, "lm_head"): logger.info("Loading EAGLE LM head weights from the target model.") self.model.lm_head = target_language_model.lm_head @@ -390,12 +495,12 @@ def validate_same_kv_cache_group(self, for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): for layer_name in kv_cache_group.layer_names: kv_cache_groups[layer_name] = id - assert len( + assert (len( set([ kv_cache_groups[layer_name] for layer_name in self.attn_layer_names - ]) - ) == 1, "All eagle layers should belong to the same kv cache group" + ])) == 1 + ), "All eagle layers should belong to the same kv cache group" # NOTE(woosuk): Currently, the below code is not used and we always use argmax