From 66926e31ee517d3180c3b97ca4a22763b2f97e29 Mon Sep 17 00:00:00 2001 From: Giancarlo Delfin Date: Sat, 21 Jun 2025 14:22:28 -0700 Subject: [PATCH 1/4] [spec decoding] add tree attention backend for v1 Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Giancarlo Delfin --- vllm/v1/attention/backends/tree_attn.py | 526 ++++++++++++++++++++++++ 1 file changed, 526 insertions(+) create mode 100644 vllm/v1/attention/backends/tree_attn.py diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py new file mode 100644 index 00000000000..76ee7d2e765 --- /dev/null +++ b/vllm/v1/attention/backends/tree_attn.py @@ -0,0 +1,526 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with TreeAttention.""" + +import ast +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch +from xformers.ops.fmha import triton_splitk +from xformers.ops.fmha.attn_bias import (AttentionBias, + PagedBlockDiagonalPaddedKeysMask) +from xformers.ops.tree_attention import tree_attention + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.logger import init_logger +from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionImpl, FlashAttentionMetadata, FlashAttentionMetadataBuilder) +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +logger = init_logger(__name__) + + +class TreeAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "TREE_ATTN_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["TreeAttentionImpl"]: + return TreeAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return TreeAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]: + return TreeAttentionMetadataBuilder + + +@dataclass +class TreeAttentionMetadata: + num_actual_tokens: int # Number of tokens excluding padding. + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + prefix_attn_bias: Optional[AttentionBias] + spec_attn_bias: Optional[torch.Tensor] + + # Attention metadata for prefill. 
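+    # Prefill requests are delegated to the FlashAttention backend; this
+    # field holds the metadata produced by FlashAttentionMetadataBuilder
+    # for the prefill slice of the batch (None when there are no prefills).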
+ prefill_attn_metadata: Optional[FlashAttentionMetadata] + + +class TreeAttentionMetadataBuilder( + AttentionMetadataBuilder[TreeAttentionMetadata]): + + def __init__( + self, + runner: "GPUModelRunner", + kv_cache_spec: AttentionSpec, + block_table: BlockTable, + ): + self.runner = runner + self.kv_cache_spec = kv_cache_spec + self.block_table = block_table + self.block_size = kv_cache_spec.block_size + + spec_config = runner.vllm_config.speculative_config + spec_token_tree = spec_config.speculative_token_tree + self.tree_choices: list[tuple[int, ...]] = ( + ast.literal_eval(spec_token_tree) + if spec_token_tree is not None else []) + self.tree_size = len(self.tree_choices) + 1 + + self.prefill_attn_metadata_builder: FlashAttentionMetadataBuilder = ( + FlashAttentionMetadataBuilder( + runner, + kv_cache_spec, + block_table, + )) + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # We now want to reorder the batch so that the "decode" requests are and + # the front and the "prefill" requests are at the using the least amount + # swaps possible. (NOTE for now we loosely use "decode" to mean requests + # where attention is likely memory-bound and "prefill" to mean requests + # where attention is likely compute-bound, TODO(lucas): figure out a + # better naming here) + decodes = [] + prefills = [] + num_decode_tokens = 0 + num_prefill_tokens = 0 + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # for now treat 1 scheduled token as "decode" even if its not, + # we should update this to something like < 8 in the future but + # currently the decode run only supports num_tokens = 1 + # For now, treat any decode step with exactly + if num_tokens == self.tree_size: + decodes.append(i) + num_decode_tokens += num_tokens + else: + prefills.append(i) + num_prefill_tokens += num_tokens + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new request are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back" + # i.e. past where the last decode should be in the reodorered with + # prefills from the front of the batch. 
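+        # Example: with decodes == [0, 2] and prefills == [1, 3], a single
+        # swap of the requests at indices 1 and 2 yields the desired order
+        # [decode, decode, prefill, prefill].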
+ # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + decode_idx = decodes[num_decodes - i] + if decode_idx < num_decodes: + break + + input_batch.swap_states(prefills[i - 1], decode_idx) + modified_batch = True + + # Save for next `build` call + # TODO(lucas): this is a bit of a hack, we should probably have a + # better way of doing this + self._num_decodes = num_decodes + self._num_prefills = num_prefills + self._num_decode_tokens = num_decode_tokens + self._num_prefill_tokens = num_prefill_tokens + + return modified_batch + + def build( + self, common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata + ) -> TreeAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + num_decodes = self._num_decodes + num_prefills = self._num_prefills + num_actual_tokens = common_attn_metadata.num_actual_tokens + num_decode_tokens = self._num_decode_tokens + num_prefill_tokens = self._num_prefill_tokens + q_start_loc = common_attn_metadata.query_start_loc + q_seqlens = torch.diff(q_start_loc) + max_query_len = common_attn_metadata.max_query_len + kv_seqlens = common_attn_metadata.seq_lens + max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) + block_table = self.block_table + slot_mapping = block_table.slot_mapping + + # If there are any prefill requests, construct the prefill + # attention metadata. + prefill_attn_metadata = None + if num_prefills > 0: + # Temporarily set the block table slot mapping tensor to the + # slice for prefill. + block_table.slot_mapping = slot_mapping[num_decode_tokens:] + # Build prefill attention metadata. + prefill_attn_metadata = self.prefill_attn_metadata_builder.build( + common_prefix_len, + CommonAttentionMetadata( + query_start_loc=q_start_loc[num_decodes:] - + q_start_loc[num_decodes], + seq_lens=kv_seqlens[num_decodes:], + num_reqs=num_prefills, + num_actual_tokens=num_prefill_tokens, + max_query_len=int(q_seqlens[num_decodes:].max().item()), + ), + ) + # Restore block table slot mapping to the original, full tensor. + block_table.slot_mapping = slot_mapping + + # Get the block table and slot mapping for paged KV. + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + slot_mapping[:num_decode_tokens].copy_( + block_table.slot_mapping_cpu[:num_decode_tokens], + non_blocking=True, + ) + # Fill unused with -1. Needed for reshape_and_cache in full cuda graph + # mode. + slot_mapping[num_actual_tokens:].fill_(-1) + + prefix_attn_bias = None + spec_attn_bias = None + if num_decodes > 0: + # Construct the prefix bias. + decode_q_seqlens = q_seqlens[:num_decodes] + decode_kv_seqlens = kv_seqlens[:num_decodes] + prefix_kv_seqlens = decode_kv_seqlens - decode_q_seqlens + prefix_attn_bias = PagedBlockDiagonalPaddedKeysMask.from_seqlens( + q_seqlen=decode_q_seqlens.tolist(), + kv_seqlen=prefix_kv_seqlens.tolist(), + page_size=self.block_size, + block_tables=block_table_tensor[:num_decodes], + device=block_table.device, + ) + # Construct the tree attention (suffix) bias. 
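+            # _prepare_tree_attn_bias builds a (tree_size, tree_size) mask
+            # that is 0 where a draft token may attend to another (itself,
+            # its ancestors in the token tree, and the shared root) and
+            # -inf everywhere else.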
+ spec_attn_bias = _prepare_tree_attn_bias( + self.tree_choices, + self.kv_cache_spec.dtype, + device=block_table.device, + ).T + + return TreeAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=q_start_loc, + max_seq_len=max_seq_len, + seq_lens=kv_seqlens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + prefix_attn_bias=prefix_attn_bias, + spec_attn_bias=spec_attn_bias, + prefill_attn_metadata=prefill_attn_metadata, + ) + + +def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]: + # Initialize depth_counts to keep track of how many choices have a + # particular depth. + depth_counts = [] + prev_depth = 0 + for path in sorted_tree_choices: + depth = len(path) + if depth != prev_depth: + depth_counts.append(0) + depth_counts[depth - 1] += 1 + prev_depth = depth + return depth_counts + + +def _prepare_tree_attn_bias( + sorted_tree_choices: list[tuple[int, ...]], + dtype: Optional[torch.dtype], + device: Optional[torch.device], +) -> torch.Tensor: + """ + Construct a Medusa-style tree attention bias as an explicit tensor. + It can be used as a spec_attn_bias ("right" or "suffix" attention part) + in tree_attention. See run_tree_attention_inner in test for a usage example. + Args: + sorted_tree_choices: tree description in the style of + https://github.com/FasterDecoding/Medusa/blob/5e9805386/medusa/model/medusa_choices.py + A typical tree description would look like: + [(node0, node1, ...), + (node0, node2), + (node0, node3), + (node1, node3), ..., + (node0, node2, ..., nodeN)] + Every tuple is corresponds to one node in the tree, encoded as a + path from one of the root nodes to the node in question. Passed + in sorted order. + + For example, a node encoded as (1, 0, 3, ..., 2) is understood as: + list all the root nodes and take node number 1 + list all children of that node and take node number 0 + list all children of that node and take node number 3 + ... + list all children of that node and take node number 2 - that's the + node encoded by this tuple + dtype: data type of the output tensor. + device: device of the output tensor. + Returns: + attention bias of shape (tree_size, tree_size), + where tree_size is the total number of nodes in the tree. 
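+
+    For example, the chain [(0,), (0, 0), (0, 0, 0)] yields a 4x4 bias that
+    is 0 on and below the diagonal and -inf above it, i.e. a standard causal
+    mask over the root token plus three chained speculative tokens.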
+ """ + depth_counts = _get_depth_counts(sorted_tree_choices) + + # +1 comes from the additional root node + tree_len = len(sorted_tree_choices) + 1 + tree_attn_mask = torch.full((tree_len, tree_len), + -torch.inf, + device=device, + dtype=dtype) + + mask_val = 0 + for i in range(tree_len): + tree_attn_mask[i, i] = mask_val + + tree_attn_mask[:, 0] = mask_val + start = 0 + for i in range(len(depth_counts)): + for j in range(depth_counts[i]): + cur_tree_choice = sorted_tree_choices[start + j] + # retrieve ancestor position + if len(cur_tree_choice) == 1: + continue + ancestor_idx = [] + for c in range(len(cur_tree_choice) - 1): + ancestor_idx.append( + sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1) + tree_attn_mask[j + start + 1, ancestor_idx] = mask_val + start += depth_counts[i] + return tree_attn_mask + + +class TreeAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + support_head_sizes = TreeAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by TreeAttention. " + f"Supported head sizes are: {support_head_sizes}. " + "Set VLLM_USE_V1=0 to use another attention backend.") + + self.prefill_attention_impl = FlashAttentionImpl( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=alibi_slopes, + sliding_window=sliding_window, + kv_cache_dtype=kv_cache_dtype, + blocksparse_params=blocksparse_params, + logits_soft_cap=logits_soft_cap, + attn_type=attn_type, + kv_sharing_target_layer_name= + None, # Skip KV reshape and cache. This class handles it. + use_irope=use_irope, + ) + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: TreeAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with TreeAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for TreeAttentionImpl") + + if attn_metadata is None: + # Profiling run. + return output + + # Cache the input KVs. + key_cache, value_cache = kv_cache.unbind(0) + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. 
+ # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + num_decode_tokens = attn_metadata.num_actual_tokens + num_decodes = attn_metadata.query_start_loc.shape[0] - 1 + prefill_attn_metadata = attn_metadata.prefill_attn_metadata + if prefill_attn_metadata is not None: + num_decode_tokens -= prefill_attn_metadata.num_actual_tokens + num_decodes -= prefill_attn_metadata.query_start_loc.shape[0] - 1 + # Perform prefill flash attention. + self.prefill_attention_impl.forward( + layer, + query[num_decode_tokens:], + key[num_decode_tokens:], + value[num_decode_tokens:], + kv_cache, + prefill_attn_metadata, + output[num_decode_tokens:], + None, + ) + + if num_decodes == 0: + # No decode requests, abort early. + return output + + # Get only the speculatively decoded q, k, and vs. + spec_q = query[:num_decode_tokens] + spec_k = key[:num_decode_tokens] + spec_v = value[:num_decode_tokens] + + # Reshape q, k, and vs to [B, M, H, D]. + spec_q = spec_q.view(num_decodes, -1, self.num_heads, self.head_size) + spec_k = spec_k.view(num_decodes, -1, self.num_kv_heads, + self.head_size) + spec_v = spec_v.view(num_decodes, -1, self.num_kv_heads, + self.head_size) + # Reshape the KV cache to [Bkv, Mk, H, D] + cache_k = key_cache.view(1, -1, self.num_kv_heads, self.head_size) + cache_v = value_cache.view(1, -1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + # GQA/MQA. Reshape q, k, and v to [B, M, G, H, K]. + spec_q = spec_q.view( + spec_q.shape[0], + spec_q.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + spec_q.shape[-1], + ) + spec_k = spec_k[:, :, :, None, :].expand( + spec_k.shape[0], + spec_k.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + spec_k.shape[-1], + ) + spec_v = spec_v[:, :, :, None, :].expand( + spec_v.shape[0], + spec_v.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + spec_v.shape[-1], + ) + # Reshape the KV cache to [Bkv, Mk, G, H, K] + cache_k = cache_k[:, :, :, None, :].expand( + cache_k.shape[0], + cache_k.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + cache_k.shape[-1], + ) + cache_v = cache_v[:, :, :, None, :].expand( + cache_v.shape[0], + cache_v.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + cache_v.shape[-1], + ) + + # Perform tree attention on the speculatively decoded tokens. 
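+        # tree_attention splits the work into a prefix part over the paged
+        # KV cache (masked by prefix_attn_bias) and a suffix part over the
+        # in-flight speculative tokens (masked by spec_attn_bias), then
+        # merges the two partial results using their log-sum-exps.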
+ output[:num_decode_tokens] = tree_attention( + q=spec_q, + spec_k=spec_k, + spec_v=spec_v, + cache_k=cache_k, + cache_v=cache_v, + prefix_op=triton_splitk.FwOp, + suffix_op=triton_splitk.FwOp, + prefix_attn_bias=attn_metadata.prefix_attn_bias, + spec_attn_bias=attn_metadata.spec_attn_bias, + ).view(-1, self.num_heads, self.head_size) + return output From 0efbc4968d6137f415af14e40641520ae79be6e1 Mon Sep 17 00:00:00 2001 From: Giancarlo Delfin Date: Thu, 26 Jun 2025 16:51:47 -0700 Subject: [PATCH 2/4] [spec decoding] add test for tree attention correctness Signed-off-by: Giancarlo Delfin --- tests/spec_decode/test_tree_attention.py | 209 +++++++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 tests/spec_decode/test_tree_attention.py diff --git a/tests/spec_decode/test_tree_attention.py b/tests/spec_decode/test_tree_attention.py new file mode 100644 index 00000000000..94b4d9da126 --- /dev/null +++ b/tests/spec_decode/test_tree_attention.py @@ -0,0 +1,209 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math + +import torch +from xformers.ops.fmha.attn_bias import PagedBlockDiagonalPaddedKeysMask + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.attention.backends.tree_attn import TreeAttentionBackend + + +class NoOpLayerModule(torch.nn.Module): + _q_scale = torch.tensor(1.0, dtype=torch.float32) + _k_scale = torch.tensor(1.0, dtype=torch.float32) + _v_scale = torch.tensor(1.0, dtype=torch.float32) + + def __init__(self): + super().__init__() + + def forward(self, x): + return x + + +def forward_attention( + batch_size: int, + num_heads: int, + num_kv_heads: int, + dim_per_head: int, + block_size: int, + max_sequence_length: int, + sequence_position: int, + q_len: int, + backends: list[type[AttentionBackend]], +) -> list[torch.Tensor]: + # Assert that the number of heads is divisible by the number of KV heads. + assert num_heads % num_kv_heads == 0 + + device = "cuda" + # Initialize q, k, and v. + q = torch.randn( + (batch_size * q_len, num_heads, dim_per_head), + device=device, + dtype=torch.bfloat16, + ) + k = torch.randn( + (batch_size * q_len, num_kv_heads, dim_per_head), + device=device, + dtype=torch.bfloat16, + ) + v = torch.randn( + (batch_size * q_len, num_kv_heads, dim_per_head), + device=device, + dtype=torch.bfloat16, + ) + + # Initialize the query and KV sequence lengths. + cu_seqlens_q = q_len * torch.arange( + batch_size + 1, device=device, dtype=torch.int32) + seqlens_q = torch.diff(cu_seqlens_q) + seqlens_kv = torch.full( + (batch_size, ), + sequence_position + q_len, + device=device, + dtype=torch.int32, + ) + max_seqlen_q = q_len + max_seqlen_k = sequence_position + q_len + num_actual_tokens = cu_seqlens_q[-1] + + # Setup the block table and KV cache for paged KV. 
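+    # Each request reserves max_sequence_length // block_size block slots in
+    # the cache; block_table[i] lists the physical block ids holding request
+    # i's KV, and unused trailing entries remain 0.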
+ assert max_sequence_length % block_size == 0 + max_block_count_per_batch = max_sequence_length // block_size + kv_cache = torch.randn( + ( + 2, + batch_size * max_block_count_per_batch, + block_size, + num_kv_heads, + dim_per_head, + ), + device=device, + dtype=torch.bfloat16, + ) + num_allocated_blocks_per_batch = math.ceil(max_seqlen_k / block_size) + block_table = torch.zeros( + (batch_size, max_block_count_per_batch), + device=device, + dtype=torch.int32, + ) + block_ids = torch.arange( + 0, + batch_size * num_allocated_blocks_per_batch, + device=device, + dtype=torch.int32, + ).view(-1, num_allocated_blocks_per_batch) + block_table[:, :num_allocated_blocks_per_batch] = block_ids + + # Setup the slot mapping for the input KVs. + slots_per_batch = [] + for i in range(batch_size): + start_offset = block_ids[i, 0] * block_size + sequence_position + slots_per_batch.append( + torch.arange( + start_offset, + start_offset + q_len, + device=device, + dtype=torch.int64, + )) + slot_mapping = torch.cat(slots_per_batch, dim=0) + + softmax_scale = q.shape[-1]**(-0.5) + layer = NoOpLayerModule() + + # Run attention for each backend and collect the outputs. + outputs = [] + for backend_cls in backends: + # Set common metadata. + attn_metadata_dict = { + "num_actual_tokens": num_actual_tokens, + "max_query_len": max_seqlen_q, + "query_start_loc": cu_seqlens_q, + "max_seq_len": max_seqlen_k, + "seq_lens": seqlens_kv, + "block_table": block_table, + "slot_mapping": slot_mapping, + } + + # Set backend-specific metadata. + if backend_cls == FlashAttentionBackend: + attn_metadata_dict["use_cascade"] = False + attn_metadata_dict["common_prefix_len"] = 0 + attn_metadata_dict["cu_prefix_query_lens"] = None + attn_metadata_dict["prefix_kv_lens"] = None + attn_metadata_dict["suffix_kv_lens"] = None + elif backend_cls == TreeAttentionBackend: + # Construct the prefix bias. + prefix_kv_seqlens = seqlens_kv - seqlens_q + prefix_attn_bias = PagedBlockDiagonalPaddedKeysMask.from_seqlens( + q_seqlen=seqlens_q.tolist(), + kv_seqlen=prefix_kv_seqlens.tolist(), + page_size=block_size, + block_tables=block_table, + device=device, + ) + attn_metadata_dict["prefix_attn_bias"] = prefix_attn_bias + # Create a chain attn bias. + chain_attn_bias = torch.triu( + torch.full((q_len, q_len), + float("-inf"), + device=device, + dtype=torch.bfloat16), + diagonal=1, + ) + attn_metadata_dict["spec_attn_bias"] = chain_attn_bias + attn_metadata_dict["prefill_attn_metadata"] = None + + # Initialize the backend implementation. + instance = backend_cls.get_impl_cls()( + num_heads=num_heads, + head_size=dim_per_head, + scale=softmax_scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + ) + + # Run forward pass and store output. 
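+        # kv_cache is cloned below so that each backend under test writes
+        # the new keys/values into its own private copy of the cache.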
+ output = torch.empty_like(q) + outputs.append( + instance.forward( + layer=layer, + query=q, + key=k, + value=v, + kv_cache=kv_cache.clone(), + attn_metadata=backend_cls.get_metadata_cls()( + **attn_metadata_dict), + output=output, + )) + return outputs + + +def test_tree_attn_correctness() -> None: + torch.cuda.manual_seed_all(0) + + for batch_size in [1, 2, 16, 32, 64]: + for num_heads in [2, 4]: + for sequence_position in [16, 1024, 2048]: + for q_len in [1, 3, 7]: + flash_attn_output, tree_attn_output = forward_attention( + batch_size=batch_size, + num_heads=num_heads, + num_kv_heads=2, + dim_per_head=128, + block_size=128, + max_sequence_length=8192, + sequence_position=sequence_position, + q_len=q_len, + backends=[FlashAttentionBackend, TreeAttentionBackend], + ) + assert torch.allclose( + flash_attn_output, tree_attn_output, atol=7.81e-3 + ), (f"outputs are not close for batch_size: {batch_size}, " + f"num_heads: {num_heads}, " + f"sequence_position: {sequence_position}, " + f"q_len: {q_len}.") From da6c40b24cca385e71138054ad8a521fc28b3f45 Mon Sep 17 00:00:00 2001 From: Giancarlo Delfin Date: Thu, 26 Jun 2025 17:50:19 -0700 Subject: [PATCH 3/4] [spec decoding] add tree attention backend selection Signed-off-by: Giancarlo Delfin --- vllm/attention/layer.py | 4 ++- vllm/attention/selector.py | 41 ++++++++++++++--------- vllm/engine/arg_utils.py | 1 + vllm/model_executor/models/llama.py | 4 +++ vllm/model_executor/models/llama_eagle.py | 2 +- vllm/platforms/cuda.py | 4 +++ vllm/platforms/interface.py | 1 + vllm/v1/attention/backends/flash_attn.py | 2 +- vllm/v1/attention/backends/tree_attn.py | 1 + 9 files changed, 42 insertions(+), 18 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 0c79aaf1355..5e6010aac99 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -52,6 +52,7 @@ def __init__( prefix: str = "", attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, + is_draft: bool = False, **extra_impl_args, ) -> None: """ @@ -135,7 +136,8 @@ def __init__( block_size, is_attention_free, blocksparse_params is not None, - use_mla=use_mla) + use_mla=use_mla, + is_draft=is_draft) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index cb577fa6730..69da9943e9b 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -27,12 +27,12 @@ def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: loaded. """ assert backend_name is not None - return _Backend[backend_name] if backend_name in _Backend.__members__ else \ - None + return _Backend[ + backend_name] if backend_name in _Backend.__members__ else None def get_env_variable_attn_backend() -> Optional[_Backend]: - ''' + """ Get the backend override specified by the vLLM attention backend environment variable, if one is specified. 
@@ -40,10 +40,9 @@ def get_env_variable_attn_backend() -> Optional[_Backend]: * _Backend enum value if an override is specified * None otherwise - ''' + """ backend_name = os.environ.get(STR_BACKEND_ENV_VAR) - return (None - if backend_name is None else backend_name_to_enum(backend_name)) + return None if backend_name is None else backend_name_to_enum(backend_name) # Global state allows a particular choice of backend @@ -57,7 +56,7 @@ def get_env_variable_attn_backend() -> Optional[_Backend]: def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: - ''' + """ Force all attention operations to use a specified backend. Passing `None` for the argument re-enables automatic @@ -66,16 +65,16 @@ def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: Arguments: * attn_backend: backend selection (None to revert to auto) - ''' + """ global forced_attn_backend forced_attn_backend = attn_backend def get_global_forced_attn_backend() -> Optional[_Backend]: - ''' + """ Get the currently-forced choice of attention backend, or None if auto-selection is currently enabled. - ''' + """ return forced_attn_backend @@ -87,6 +86,7 @@ def get_attn_backend( is_attention_free: bool, is_blocksparse: bool = False, use_mla: bool = False, + is_draft: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" # Accessing envs.* behind an @lru_cache decorator can cause the wrong @@ -102,6 +102,7 @@ def get_attn_backend( is_blocksparse=is_blocksparse, use_v1=envs.VLLM_USE_V1, use_mla=use_mla, + is_draft=is_draft, ) @@ -115,11 +116,20 @@ def _cached_get_attn_backend( is_blocksparse: bool = False, use_v1: bool = False, use_mla: bool = False, + is_draft: bool = False, ) -> Type[AttentionBackend]: + # Draft model backend is currently forced to FlashAttentionBackend for + # consistency with EagleProposer using FlashAttentionMetadata. + if use_v1 and is_draft: + from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend + + return FlashAttentionBackend + if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") from vllm.attention.backends.blocksparse_attn import ( BlocksparseFlashAttentionBackend) + return BlocksparseFlashAttentionBackend # If there are no attention layers (e.g. we are running Mamba), @@ -127,6 +137,7 @@ def _cached_get_attn_backend( if is_attention_free: from vllm.attention.backends.placeholder_attn import ( PlaceholderAttentionBackend) + return PlaceholderAttentionBackend # Check whether a particular choice of backend was @@ -135,8 +146,8 @@ def _cached_get_attn_backend( # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND # ENVIRONMENT VARIABLE. 
selected_backend = None - backend_by_global_setting: Optional[_Backend] = ( - get_global_forced_attn_backend()) + backend_by_global_setting: Optional[ + _Backend] = get_global_forced_attn_backend() if backend_by_global_setting is not None: selected_backend = backend_by_global_setting else: @@ -157,8 +168,8 @@ def _cached_get_attn_backend( @contextmanager def global_force_attn_backend_context_manager( - attn_backend: _Backend) -> Generator[None, None, None]: - ''' + attn_backend: _Backend, ) -> Generator[None, None, None]: + """ Globally force a vLLM attention backend override within a context manager, reverting the global attention backend override to its prior state upon exiting the context @@ -171,7 +182,7 @@ def global_force_attn_backend_context_manager( Returns: * Generator - ''' + """ # Save the current state of the global backend override (if any) original_value = get_global_forced_attn_backend() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6c908f88b9a..fb98acbfdd6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1442,6 +1442,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "ROCM_AITER_MLA", "TORCH_SDPA_VLLM_V1", "FLEX_ATTENTION", + "TREE_ATTN", ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5d5080479e5..b488c320a50 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -113,6 +113,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, prefix: str = "", attn_type: str = AttentionType.DECODER, + is_draft: bool = False, ) -> None: super().__init__() layer_idx = extract_layer_index(prefix) @@ -190,6 +191,7 @@ def __init__( per_layer_sliding_window=sliding_window, attn_type=attn_type, prefix=f"{prefix}.attn", + is_draft=is_draft, ) def forward( @@ -231,6 +233,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + is_draft: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -275,6 +278,7 @@ def __init__( cache_config=cache_config, prefix=f"{prefix}.self_attn", attn_type=attn_type, + is_draft=is_draft, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index c7690604c1d..34f4b45e701 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -31,7 +31,7 @@ def __init__( disable_input_layernorm: bool, prefix: str = "", ) -> None: - super().__init__(config, prefix=prefix) + super().__init__(config, prefix=prefix, is_draft=True) # Skip the input_layernorm # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 879d094f657..2dbefb3c022 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -248,6 +248,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info_once("Using Flash Attention backend on V1 engine.") return ("vllm.v1.attention.backends." "flash_attn.FlashAttentionBackend") + elif selected_backend == _Backend.TREE_ATTN: + logger.info_once("Using Tree Attention backend on V1 engine.") + return ("vllm.v1.attention.backends." 
+ "tree_attn.TreeAttentionBackend") # Default backends for V1 engine # Prefer FlashInfer for Blackwell GPUs if installed diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 0f08bf98633..023dc6c78da 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -61,6 +61,7 @@ class _Backend(enum.Enum): DUAL_CHUNK_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto() FLEX_ATTENTION = enum.auto() + TREE_ATTN = enum.auto() class PlatformEnum(enum.Enum): diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 527b3115341..56f3303b950 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -134,7 +134,7 @@ def _get_sliding_window_configs( sliding_window_configs: set[Optional[tuple[int, int]]] = set() layers = get_layers_from_vllm_config(vllm_config, Attention) for layer in layers.values(): - assert isinstance(layer.impl, FlashAttentionImpl) + assert hasattr(layer.impl, "sliding_window") sliding_window_configs.add(layer.impl.sliding_window) return sliding_window_configs diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 76ee7d2e765..71a0487b419 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -380,6 +380,7 @@ def __init__( None, # Skip KV reshape and cache. This class handles it. use_irope=use_irope, ) + self.sliding_window = self.prefill_attention_impl.sliding_window def forward( self, From b82c86b74548d2ed84ad346083a13a0e021dc51d Mon Sep 17 00:00:00 2001 From: Giancarlo Delfin Date: Thu, 10 Jul 2025 22:22:28 -0700 Subject: [PATCH 4/4] [spec decoding] implement proposing tree drafts Signed-off-by: Giancarlo Delfin --- pyproject.toml | 2 +- vllm/attention/layer.py | 4 +- vllm/attention/selector.py | 41 +- vllm/config.py | 7 + vllm/engine/arg_utils.py | 1 - vllm/model_executor/models/llama.py | 4 - vllm/model_executor/models/llama_eagle.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 2 +- vllm/v1/attention/backends/tree_attn.py | 476 +++++++++++----------- vllm/v1/attention/backends/utils.py | 11 + vllm/v1/spec_decode/eagle.py | 283 +++++++++---- 11 files changed, 468 insertions(+), 365 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e8c2403af06..5926ab8c862 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ ignore_patterns = [ [tool.ruff] # Allow lines to be as long as 80. -line-length = 80 +line-length = 90 [tool.ruff.lint.per-file-ignores] "vllm/third_party/**" = ["ALL"] diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 5e6010aac99..0c79aaf1355 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -52,7 +52,6 @@ def __init__( prefix: str = "", attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, - is_draft: bool = False, **extra_impl_args, ) -> None: """ @@ -136,8 +135,7 @@ def __init__( block_size, is_attention_free, blocksparse_params is not None, - use_mla=use_mla, - is_draft=is_draft) + use_mla=use_mla) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 69da9943e9b..cb577fa6730 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -27,12 +27,12 @@ def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: loaded. 
""" assert backend_name is not None - return _Backend[ - backend_name] if backend_name in _Backend.__members__ else None + return _Backend[backend_name] if backend_name in _Backend.__members__ else \ + None def get_env_variable_attn_backend() -> Optional[_Backend]: - """ + ''' Get the backend override specified by the vLLM attention backend environment variable, if one is specified. @@ -40,9 +40,10 @@ def get_env_variable_attn_backend() -> Optional[_Backend]: * _Backend enum value if an override is specified * None otherwise - """ + ''' backend_name = os.environ.get(STR_BACKEND_ENV_VAR) - return None if backend_name is None else backend_name_to_enum(backend_name) + return (None + if backend_name is None else backend_name_to_enum(backend_name)) # Global state allows a particular choice of backend @@ -56,7 +57,7 @@ def get_env_variable_attn_backend() -> Optional[_Backend]: def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: - """ + ''' Force all attention operations to use a specified backend. Passing `None` for the argument re-enables automatic @@ -65,16 +66,16 @@ def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: Arguments: * attn_backend: backend selection (None to revert to auto) - """ + ''' global forced_attn_backend forced_attn_backend = attn_backend def get_global_forced_attn_backend() -> Optional[_Backend]: - """ + ''' Get the currently-forced choice of attention backend, or None if auto-selection is currently enabled. - """ + ''' return forced_attn_backend @@ -86,7 +87,6 @@ def get_attn_backend( is_attention_free: bool, is_blocksparse: bool = False, use_mla: bool = False, - is_draft: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" # Accessing envs.* behind an @lru_cache decorator can cause the wrong @@ -102,7 +102,6 @@ def get_attn_backend( is_blocksparse=is_blocksparse, use_v1=envs.VLLM_USE_V1, use_mla=use_mla, - is_draft=is_draft, ) @@ -116,20 +115,11 @@ def _cached_get_attn_backend( is_blocksparse: bool = False, use_v1: bool = False, use_mla: bool = False, - is_draft: bool = False, ) -> Type[AttentionBackend]: - # Draft model backend is currently forced to FlashAttentionBackend for - # consistency with EagleProposer using FlashAttentionMetadata. - if use_v1 and is_draft: - from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend - - return FlashAttentionBackend - if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") from vllm.attention.backends.blocksparse_attn import ( BlocksparseFlashAttentionBackend) - return BlocksparseFlashAttentionBackend # If there are no attention layers (e.g. we are running Mamba), @@ -137,7 +127,6 @@ def _cached_get_attn_backend( if is_attention_free: from vllm.attention.backends.placeholder_attn import ( PlaceholderAttentionBackend) - return PlaceholderAttentionBackend # Check whether a particular choice of backend was @@ -146,8 +135,8 @@ def _cached_get_attn_backend( # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND # ENVIRONMENT VARIABLE. 
selected_backend = None - backend_by_global_setting: Optional[ - _Backend] = get_global_forced_attn_backend() + backend_by_global_setting: Optional[_Backend] = ( + get_global_forced_attn_backend()) if backend_by_global_setting is not None: selected_backend = backend_by_global_setting else: @@ -168,8 +157,8 @@ def _cached_get_attn_backend( @contextmanager def global_force_attn_backend_context_manager( - attn_backend: _Backend, ) -> Generator[None, None, None]: - """ + attn_backend: _Backend) -> Generator[None, None, None]: + ''' Globally force a vLLM attention backend override within a context manager, reverting the global attention backend override to its prior state upon exiting the context @@ -182,7 +171,7 @@ def global_force_attn_backend_context_manager( Returns: * Generator - """ + ''' # Save the current state of the global backend override (if any) original_value = get_global_forced_attn_backend() diff --git a/vllm/config.py b/vllm/config.py index 84aa14b7c86..6d251e652c5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2724,6 +2724,13 @@ def __post_init__(self): f"num_speculative_tokens:{self.num_speculative_tokens}" f" must be divisible by {n_predict=}") + if self.speculative_token_tree is None: + # Generate chain of tokens. + self.speculative_token_tree = str([[ + (i + 1) * (0, ) + for i in range(self.num_speculative_tokens) + ]]) + self.draft_tensor_parallel_size = \ SpeculativeConfig._verify_and_get_draft_tp( self.target_parallel_config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fb98acbfdd6..bc7c90e9911 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1427,7 +1427,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # No XFormers so far. 
V1_BACKENDS = [ "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index b488c320a50..5d5080479e5 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -113,7 +113,6 @@ def __init__( cache_config: Optional[CacheConfig] = None, prefix: str = "", attn_type: str = AttentionType.DECODER, - is_draft: bool = False, ) -> None: super().__init__() layer_idx = extract_layer_index(prefix) @@ -191,7 +190,6 @@ def __init__( per_layer_sliding_window=sliding_window, attn_type=attn_type, prefix=f"{prefix}.attn", - is_draft=is_draft, ) def forward( @@ -233,7 +231,6 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - is_draft: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -278,7 +275,6 @@ def __init__( cache_config=cache_config, prefix=f"{prefix}.self_attn", attn_type=attn_type, - is_draft=is_draft, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 34f4b45e701..c7690604c1d 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -31,7 +31,7 @@ def __init__( disable_input_layernorm: bool, prefix: str = "", ) -> None: - super().__init__(config, prefix=prefix, is_draft=True) + super().__init__(config, prefix=prefix) # Skip the input_layernorm # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 56f3303b950..527b3115341 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -134,7 +134,7 @@ def _get_sliding_window_configs( sliding_window_configs: set[Optional[tuple[int, int]]] = set() layers = get_layers_from_vllm_config(vllm_config, Attention) for layer in layers.values(): - assert hasattr(layer.impl, "sliding_window") + assert isinstance(layer.impl, FlashAttentionImpl) sliding_window_configs.add(layer.impl.sliding_window) return sliding_window_configs diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 71a0487b419..8f3b1834461 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -9,13 +9,15 @@ from xformers.ops.fmha import triton_splitk from xformers.ops.fmha.attn_bias import (AttentionBias, PagedBlockDiagonalPaddedKeysMask) -from xformers.ops.tree_attention import tree_attention +from xformers.ops.tree_attention import (_get_depth_counts, + _prepare_tree_attn_bias, + tree_attention) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.logger import init_logger -from vllm.v1.attention.backends.flash_attn import ( - FlashAttentionImpl, FlashAttentionMetadata, FlashAttentionMetadataBuilder) +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec @@ -26,6 +28,8 @@ from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm import _custom_ops as ops + logger = 
init_logger(__name__) @@ -74,11 +78,87 @@ class TreeAttentionMetadata: seq_lens: torch.Tensor block_table: torch.Tensor slot_mapping: torch.Tensor - prefix_attn_bias: Optional[AttentionBias] - spec_attn_bias: Optional[torch.Tensor] - # Attention metadata for prefill. - prefill_attn_metadata: Optional[FlashAttentionMetadata] + num_prefill_tokens: int = 0 + num_decode_tokens: int = 0 + num_prefills: int = 0 + num_decodes: int = 0 + + prefix_attn_bias: Optional[AttentionBias] = None + spec_attn_bias: Optional[torch.Tensor] = None + + # Cached Prefill/decode metadata. + _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None + _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + # Recover cached prefill-phase attention + # metadata structure + return self._cached_prefill_metadata + + q_start_loc = self.query_start_loc[self.num_decodes:] + q_seqlens = torch.diff(q_start_loc) + kv_seqlens = self.seq_lens[self.num_decodes:] + # Construct & cache prefill-phase attention metadata structure + self._cached_prefill_metadata = TreeAttentionMetadata( + num_actual_tokens=self.num_prefill_tokens, + max_query_len=int(q_seqlens.max().item()), + query_start_loc=q_start_loc - q_start_loc[0], + max_seq_len=int(kv_seqlens.max().item()), + seq_lens=kv_seqlens, + block_table=self.block_table[self.num_decodes:], + slot_mapping=self.slot_mapping[self.num_decode_tokens:], + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["TreeAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + # Recover cached decode-phase attention + # metadata structure + return self._cached_decode_metadata + + q_start_loc = self.query_start_loc[:self.num_decodes + 1] + q_seqlens = torch.diff(q_start_loc) + kv_seqlens = self.seq_lens[:self.num_decodes] + # Construct & cache decode-phase attention metadata structure + self._cached_decode_metadata = TreeAttentionMetadata( + num_actual_tokens=self.num_decode_tokens, + max_query_len=int(q_seqlens.max().item()), + query_start_loc=q_start_loc, + max_seq_len=int(kv_seqlens.max().item()), + seq_lens=kv_seqlens, + block_table=self.block_table[:self.num_decodes], + slot_mapping=self.slot_mapping[:self.num_decode_tokens], + prefix_attn_bias=self.prefix_attn_bias, + spec_attn_bias=self.spec_attn_bias, + ) + return self._cached_decode_metadata + + @staticmethod + def from_eagle_attn_metadata( + flash_attn_metadata: "FlashAttentionMetadata", + ) -> "TreeAttentionMetadata": + num_prefills = flash_attn_metadata.query_start_loc.shape[0] - 1 + return TreeAttentionMetadata( + num_actual_tokens=flash_attn_metadata.num_actual_tokens, + num_prefill_tokens=flash_attn_metadata.num_actual_tokens, + num_prefills=num_prefills, + max_query_len=flash_attn_metadata.max_query_len, + query_start_loc=flash_attn_metadata.query_start_loc, + max_seq_len=flash_attn_metadata.max_seq_len, + seq_lens=flash_attn_metadata.seq_lens, + block_table=flash_attn_metadata.block_table, + slot_mapping=flash_attn_metadata.slot_mapping, + ) class TreeAttentionMetadataBuilder( @@ -97,17 +177,18 @@ def __init__( spec_config = runner.vllm_config.speculative_config spec_token_tree = spec_config.speculative_token_tree - self.tree_choices: list[tuple[int, ...]] = ( - ast.literal_eval(spec_token_tree) - if spec_token_tree is not None else []) - 
self.tree_size = len(self.tree_choices) + 1 - - self.prefill_attn_metadata_builder: FlashAttentionMetadataBuilder = ( - FlashAttentionMetadataBuilder( - runner, - kv_cache_spec, - block_table, - )) + tree_choices: list[tuple[int, + ...]] = (ast.literal_eval(spec_token_tree) if + spec_token_tree is not None else []) + # Construct the tree attention bias. + depth_counts = _get_depth_counts(tree_choices) + self.tree_attn_bias = _prepare_tree_attn_bias( + tree_choices, + depth_counts, + dtype=self.kv_cache_spec.dtype, + device=block_table.device, + ) + self.suffix_attn_bias = self.tree_attn_bias def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -120,7 +201,6 @@ def reorder_batch(self, input_batch: "InputBatch", decodes = [] prefills = [] num_decode_tokens = 0 - num_prefill_tokens = 0 for i, req_id in enumerate(input_batch.req_ids): num_tokens = scheduler_output.num_scheduled_tokens[req_id] @@ -128,12 +208,11 @@ def reorder_batch(self, input_batch: "InputBatch", # we should update this to something like < 8 in the future but # currently the decode run only supports num_tokens = 1 # For now, treat any decode step with exactly - if num_tokens == self.tree_size: + if num_tokens == self.suffix_attn_bias.shape[0]: decodes.append(i) num_decode_tokens += num_tokens else: prefills.append(i) - num_prefill_tokens += num_tokens # We hope that this is fairly minimal since decodes # should be around for a number of iterations so hopefully they are @@ -163,9 +242,7 @@ def reorder_batch(self, input_batch: "InputBatch", # TODO(lucas): this is a bit of a hack, we should probably have a # better way of doing this self._num_decodes = num_decodes - self._num_prefills = num_prefills self._num_decode_tokens = num_decode_tokens - self._num_prefill_tokens = num_prefill_tokens return modified_batch @@ -175,10 +252,10 @@ def build( ) -> TreeAttentionMetadata: num_reqs = common_attn_metadata.num_reqs num_decodes = self._num_decodes - num_prefills = self._num_prefills + num_prefills = num_reqs - num_decodes num_actual_tokens = common_attn_metadata.num_actual_tokens num_decode_tokens = self._num_decode_tokens - num_prefill_tokens = self._num_prefill_tokens + num_prefill_tokens = num_actual_tokens - num_decode_tokens q_start_loc = common_attn_metadata.query_start_loc q_seqlens = torch.diff(q_start_loc) max_query_len = common_attn_metadata.max_query_len @@ -187,32 +264,10 @@ def build( block_table = self.block_table slot_mapping = block_table.slot_mapping - # If there are any prefill requests, construct the prefill - # attention metadata. - prefill_attn_metadata = None - if num_prefills > 0: - # Temporarily set the block table slot mapping tensor to the - # slice for prefill. - block_table.slot_mapping = slot_mapping[num_decode_tokens:] - # Build prefill attention metadata. - prefill_attn_metadata = self.prefill_attn_metadata_builder.build( - common_prefix_len, - CommonAttentionMetadata( - query_start_loc=q_start_loc[num_decodes:] - - q_start_loc[num_decodes], - seq_lens=kv_seqlens[num_decodes:], - num_reqs=num_prefills, - num_actual_tokens=num_prefill_tokens, - max_query_len=int(q_seqlens[num_decodes:].max().item()), - ), - ) - # Restore block table slot mapping to the original, full tensor. - block_table.slot_mapping = slot_mapping - # Get the block table and slot mapping for paged KV. 
block_table_tensor = block_table.get_device_tensor()[:num_reqs] - slot_mapping[:num_decode_tokens].copy_( - block_table.slot_mapping_cpu[:num_decode_tokens], + slot_mapping[:num_actual_tokens].copy_( + block_table.slot_mapping_cpu[:num_actual_tokens], non_blocking=True, ) # Fill unused with -1. Needed for reshape_and_cache in full cuda graph @@ -220,7 +275,6 @@ def build( slot_mapping[num_actual_tokens:].fill_(-1) prefix_attn_bias = None - spec_attn_bias = None if num_decodes > 0: # Construct the prefix bias. decode_q_seqlens = q_seqlens[:num_decodes] @@ -233,15 +287,13 @@ def build( block_tables=block_table_tensor[:num_decodes], device=block_table.device, ) - # Construct the tree attention (suffix) bias. - spec_attn_bias = _prepare_tree_attn_bias( - self.tree_choices, - self.kv_cache_spec.dtype, - device=block_table.device, - ).T return TreeAttentionMetadata( num_actual_tokens=num_actual_tokens, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_decodes=num_decodes, max_query_len=max_query_len, query_start_loc=q_start_loc, max_seq_len=max_seq_len, @@ -249,88 +301,33 @@ def build( block_table=block_table_tensor, slot_mapping=slot_mapping, prefix_attn_bias=prefix_attn_bias, - spec_attn_bias=spec_attn_bias, - prefill_attn_metadata=prefill_attn_metadata, + spec_attn_bias=self.suffix_attn_bias, ) + def build_for_drafting( + self, + common_attn_metadata: CommonAttentionMetadata, + tree_level_offset: int, + ) -> TreeAttentionMetadata: + orig_num_decodes = self._num_decodes + orig_num_decode_tokens = self._num_decode_tokens + # While drafting, all requests are treated as decodes. + self._num_decodes = common_attn_metadata.num_reqs + self._num_decode_tokens = common_attn_metadata.num_actual_tokens + + # Slice the suffix attention bias so that + query_len = common_attn_metadata.max_query_len + start, end = tree_level_offset, tree_level_offset + query_len + self.suffix_attn_bias = self.tree_attn_bias[start:end, start:end] -def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]: - # Initialize depth_counts to keep track of how many choices have a - # particular depth. - depth_counts = [] - prev_depth = 0 - for path in sorted_tree_choices: - depth = len(path) - if depth != prev_depth: - depth_counts.append(0) - depth_counts[depth - 1] += 1 - prev_depth = depth - return depth_counts - - -def _prepare_tree_attn_bias( - sorted_tree_choices: list[tuple[int, ...]], - dtype: Optional[torch.dtype], - device: Optional[torch.device], -) -> torch.Tensor: - """ - Construct a Medusa-style tree attention bias as an explicit tensor. - It can be used as a spec_attn_bias ("right" or "suffix" attention part) - in tree_attention. See run_tree_attention_inner in test for a usage example. - Args: - sorted_tree_choices: tree description in the style of - https://github.com/FasterDecoding/Medusa/blob/5e9805386/medusa/model/medusa_choices.py - A typical tree description would look like: - [(node0, node1, ...), - (node0, node2), - (node0, node3), - (node1, node3), ..., - (node0, node2, ..., nodeN)] - Every tuple is corresponds to one node in the tree, encoded as a - path from one of the root nodes to the node in question. Passed - in sorted order. - - For example, a node encoded as (1, 0, 3, ..., 2) is understood as: - list all the root nodes and take node number 1 - list all children of that node and take node number 0 - list all children of that node and take node number 3 - ... 
- list all children of that node and take node number 2 - that's the - node encoded by this tuple - dtype: data type of the output tensor. - device: device of the output tensor. - Returns: - attention bias of shape (tree_size, tree_size), - where tree_size is the total number of nodes in the tree. - """ - depth_counts = _get_depth_counts(sorted_tree_choices) - - # +1 comes from the additional root node - tree_len = len(sorted_tree_choices) + 1 - tree_attn_mask = torch.full((tree_len, tree_len), - -torch.inf, - device=device, - dtype=dtype) - - mask_val = 0 - for i in range(tree_len): - tree_attn_mask[i, i] = mask_val - - tree_attn_mask[:, 0] = mask_val - start = 0 - for i in range(len(depth_counts)): - for j in range(depth_counts[i]): - cur_tree_choice = sorted_tree_choices[start + j] - # retrieve ancestor position - if len(cur_tree_choice) == 1: - continue - ancestor_idx = [] - for c in range(len(cur_tree_choice) - 1): - ancestor_idx.append( - sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1) - tree_attn_mask[j + start + 1, ancestor_idx] = mask_val - start += depth_counts[i] - return tree_attn_mask + # Build attention bias. + attn_metadata = self.build(0, common_attn_metadata) + + # Reset properties to original values. + self._num_decodes = orig_num_decodes + self._num_decode_tokens = orig_num_decode_tokens + self.suffix_attn_bias = self.tree_attn_bias + return attn_metadata class TreeAttentionImpl(AttentionImpl): @@ -357,6 +354,17 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) support_head_sizes = TreeAttentionBackend.get_supported_head_sizes() if head_size not in support_head_sizes: @@ -365,23 +373,6 @@ def __init__( f"Supported head sizes are: {support_head_sizes}. " "Set VLLM_USE_V1=0 to use another attention backend.") - self.prefill_attention_impl = FlashAttentionImpl( - num_heads=num_heads, - head_size=head_size, - scale=scale, - num_kv_heads=num_kv_heads, - alibi_slopes=alibi_slopes, - sliding_window=sliding_window, - kv_cache_dtype=kv_cache_dtype, - blocksparse_params=blocksparse_params, - logits_soft_cap=logits_soft_cap, - attn_type=attn_type, - kv_sharing_target_layer_name= - None, # Skip KV reshape and cache. This class handles it. - use_irope=use_irope, - ) - self.sliding_window = self.prefill_attention_impl.sliding_window - def forward( self, layer: torch.nn.Module, @@ -425,7 +416,7 @@ def forward( # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. 
- torch.ops._C_cache_ops.reshape_and_cache_flash( + ops.reshape_and_cache_flash( key, value, key_cache, @@ -436,92 +427,99 @@ def forward( layer._v_scale, ) - num_decode_tokens = attn_metadata.num_actual_tokens - num_decodes = attn_metadata.query_start_loc.shape[0] - 1 - prefill_attn_metadata = attn_metadata.prefill_attn_metadata - if prefill_attn_metadata is not None: - num_decode_tokens -= prefill_attn_metadata.num_actual_tokens - num_decodes -= prefill_attn_metadata.query_start_loc.shape[0] - 1 - # Perform prefill flash attention. - self.prefill_attention_impl.forward( - layer, - query[num_decode_tokens:], - key[num_decode_tokens:], - value[num_decode_tokens:], - kv_cache, - prefill_attn_metadata, - output[num_decode_tokens:], - None, - ) - - if num_decodes == 0: - # No decode requests, abort early. - return output - - # Get only the speculatively decoded q, k, and vs. - spec_q = query[:num_decode_tokens] - spec_k = key[:num_decode_tokens] - spec_v = value[:num_decode_tokens] - - # Reshape q, k, and vs to [B, M, H, D]. - spec_q = spec_q.view(num_decodes, -1, self.num_heads, self.head_size) - spec_k = spec_k.view(num_decodes, -1, self.num_kv_heads, - self.head_size) - spec_v = spec_v.view(num_decodes, -1, self.num_kv_heads, - self.head_size) - # Reshape the KV cache to [Bkv, Mk, H, D] - cache_k = key_cache.view(1, -1, self.num_kv_heads, self.head_size) - cache_v = value_cache.view(1, -1, self.num_kv_heads, self.head_size) - - if self.num_kv_heads != self.num_heads: - # GQA/MQA. Reshape q, k, and v to [B, M, G, H, K]. - spec_q = spec_q.view( - spec_q.shape[0], - spec_q.shape[1], - self.num_kv_heads, - self.num_queries_per_kv, - spec_q.shape[-1], - ) - spec_k = spec_k[:, :, :, None, :].expand( - spec_k.shape[0], - spec_k.shape[1], - self.num_kv_heads, - self.num_queries_per_kv, - spec_k.shape[-1], - ) - spec_v = spec_v[:, :, :, None, :].expand( - spec_v.shape[0], - spec_v.shape[1], - self.num_kv_heads, - self.num_queries_per_kv, - spec_v.shape[-1], - ) - # Reshape the KV cache to [Bkv, Mk, G, H, K] - cache_k = cache_k[:, :, :, None, :].expand( - cache_k.shape[0], - cache_k.shape[1], - self.num_kv_heads, - self.num_queries_per_kv, - cache_k.shape[-1], - ) - cache_v = cache_v[:, :, :, None, :].expand( - cache_v.shape[0], - cache_v.shape[1], - self.num_kv_heads, - self.num_queries_per_kv, - cache_v.shape[-1], + num_actual_tokens = attn_metadata.num_actual_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + num_decodes = attn_metadata.num_decodes + if prefill_meta := attn_metadata.prefill_metadata: + descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, + key.shape[1]) + unified_attention( + q=query[num_decode_tokens:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[num_decode_tokens:num_actual_tokens], + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + seqused_k=prefill_meta.seq_lens, + max_seqlen_k=prefill_meta.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=prefill_meta.block_table, + softcap=self.logits_soft_cap, + q_descale=None, # Not supported + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), ) - # Perform tree attention on the speculatively decoded tokens. 
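# Illustrative sketch (assumed helper, not part of this patch): how a dense
# Medusa-style suffix bias of the kind consumed as spec_attn_bias by
# tree_attention can be built from the sorted tree choices. Position 0 is the
# tree root; entry [i, j] is 0 when draft j is the root, an ancestor of draft
# i, or draft i itself, and -inf otherwise, so each draft only attends along
# its own path through the tree.
import torch


def build_suffix_attn_bias(
    sorted_tree_choices: list[tuple[int, ...]],
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    tree_len = len(sorted_tree_choices) + 1  # +1 for the root node
    bias = torch.full((tree_len, tree_len), float("-inf"), dtype=dtype)
    diag = torch.arange(tree_len)
    bias[diag, diag] = 0.0  # every node attends to itself
    bias[:, 0] = 0.0        # every node attends to the root
    for i, path in enumerate(sorted_tree_choices):
        # Unmask every proper ancestor along this node's path.
        for depth in range(len(path) - 1):
            ancestor = sorted_tree_choices.index(path[:depth + 1]) + 1
            bias[i + 1, ancestor] = 0.0
    return bias


# Example: for [(0,), (1,), (0, 0)], row 3 (draft (0, 0)) ends up unmasked at
# columns 0, 1 and 3: the root, its parent (0,), and itself.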
- output[:num_decode_tokens] = tree_attention( - q=spec_q, - spec_k=spec_k, - spec_v=spec_v, - cache_k=cache_k, - cache_v=cache_v, - prefix_op=triton_splitk.FwOp, - suffix_op=triton_splitk.FwOp, - prefix_attn_bias=attn_metadata.prefix_attn_bias, - spec_attn_bias=attn_metadata.spec_attn_bias, - ).view(-1, self.num_heads, self.head_size) + if decode_meta := attn_metadata.decode_metadata: + # Get only the speculatively decoded q, k, and vs. + spec_q = query[:num_decode_tokens] + spec_k = key[:num_decode_tokens] + spec_v = value[:num_decode_tokens] + + # Reshape q, k, and vs to [B, M, H, D]. + spec_q = spec_q.view(num_decodes, -1, self.num_heads, + self.head_size) + spec_k = spec_k.view(num_decodes, -1, self.num_kv_heads, + self.head_size) + spec_v = spec_v.view(num_decodes, -1, self.num_kv_heads, + self.head_size) + # Reshape the KV cache to [Bkv, Mk, H, D] + cache_k = key_cache.view(1, -1, self.num_kv_heads, self.head_size) + cache_v = value_cache.view(1, -1, self.num_kv_heads, + self.head_size) + + if self.num_kv_heads != self.num_heads: + # GQA/MQA. Reshape q, k, and v to [B, M, G, H, K]. + spec_q = spec_q.view( + spec_q.shape[0], + spec_q.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + spec_q.shape[-1], + ) + spec_k = spec_k[:, :, :, None, :].expand( + spec_k.shape[0], + spec_k.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + spec_k.shape[-1], + ) + spec_v = spec_v[:, :, :, None, :].expand( + spec_v.shape[0], + spec_v.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + spec_v.shape[-1], + ) + # Reshape the KV cache to [Bkv, Mk, G, H, K] + cache_k = cache_k[:, :, :, None, :].expand( + cache_k.shape[0], + cache_k.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + cache_k.shape[-1], + ) + cache_v = cache_v[:, :, :, None, :].expand( + cache_v.shape[0], + cache_v.shape[1], + self.num_kv_heads, + self.num_queries_per_kv, + cache_v.shape[-1], + ) + + # Perform tree attention on the speculatively decoded tokens. + output[:num_decode_tokens] = tree_attention( + q=spec_q, + spec_k=spec_k, + spec_v=spec_v, + cache_k=cache_k, + cache_v=cache_v, + prefix_op=triton_splitk.FwOp, + suffix_op=triton_splitk.FwOp, + prefix_attn_bias=decode_meta.prefix_attn_bias, + spec_attn_bias=decode_meta.spec_attn_bias, + ).view(-1, self.num_heads, self.head_size) return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 8083f200260..7bbf0edb80f 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -77,6 +77,17 @@ def build_for_cudagraph_capture( return self.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata) + def build_for_drafting( + self, + common_attn_metadata: CommonAttentionMetadata, + tree_level_offset: int, + ) -> M: + """ + Build attention metadata for draft model. Uses build by default. 
+ """ + return self.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + def use_cascade_attention( self, common_prefix_len: int, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 156f5764e8d..89ad1d6d01c 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast + import torch import torch.nn as nn @@ -14,6 +16,8 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata, FlashAttentionMetadata) +from vllm.v1.attention.backends.tree_attn import (TreeAttentionBackend, + TreeAttentionMetadata) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel @@ -41,10 +45,8 @@ def __init__( self.dtype = vllm_config.model_config.dtype self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size - self.num_speculative_tokens = ( - self.speculative_config.num_speculative_tokens) - self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + self.num_speculative_tokens = self.speculative_config.num_speculative_tokens + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's # hidden size (e.g., Llama 3.3 70B). @@ -68,12 +70,52 @@ def __init__( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device) - # We need +1 here because the arange is used to set query_start_loc, - # which has one more element than batch_size. - self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + - 1, - device=device, - dtype=torch.int32) + + max_batch_size = vllm_config.scheduler_config.max_num_seqs + self.arange = torch.arange( + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. + max_batch_size + 1, + device=device, + dtype=torch.int32, + ) + + # Parse the speculative token tree. + spec_token_tree = self.speculative_config.speculative_token_tree + self.tree_choices: list[tuple[int, + ...]] = ast.literal_eval(spec_token_tree) + tree_depth = len(self.tree_choices[-1]) + # Precompute per-level properties of the tree. + num_drafts_per_level = [0] * tree_depth + for node in self.tree_choices: + num_drafts_per_level[len(node) - 1] += 1 + self.cu_drafts_per_level = [num_drafts_per_level[0]] + self.child_drafts_per_level = [num_drafts_per_level[0]] + for level in range(1, tree_depth): + self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] + + num_drafts_per_level[level]) + self.child_drafts_per_level.append(num_drafts_per_level[level] // + num_drafts_per_level[level - 1]) + # Find the first level where the tree branches off into one or more + # children. + self.first_branching_level = None + for level in range(tree_depth): + if self.cu_drafts_per_level[level] > level + 1: + self.first_branching_level = level + break + # Precompute draft position offsets in flattened tree. 
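# Illustrative sketch (toy values, not from this patch): two kinds of draft
# positions are used during drafting. RoPE positions advance with tree depth
# (base_position + level + 1), while KV-cache slots are assigned from each
# draft's index in the flattened tree (base_position + 1 up to
# base_position + num_tree_drafts), which is what the per-request offsets
# precomputed below encode.
import torch

base_positions = torch.tensor([100, 57])  # last verified position per request
num_tree_drafts = 6                       # nodes in the flattened draft tree
offsets = torch.arange(1, num_tree_drafts + 1).repeat(2, 1)
flattened_positions = base_positions.view(-1, 1) + offsets
# tensor([[101, 102, 103, 104, 105, 106],
#         [ 58,  59,  60,  61,  62,  63]])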
+ self.tree_draft_pos_offsets = torch.arange( + 1, + len(self.tree_choices) + 1, + device=device, + dtype=torch.int32, + ).repeat(max_batch_size, 1) + + def get_padded_num_input_tokens(self, num_tokens: int): + return (self.vllm_config.pad_for_cudagraph(num_tokens) + if self.use_cuda_graph + and num_tokens <= self.cudagraph_batch_sizes[-1] else + num_tokens) def propose( self, @@ -133,6 +175,11 @@ def propose( prefix_kv_lens=None, suffix_kv_lens=None, ) + + if (self.runner is not None + and self.runner.attn_backends[0] == TreeAttentionBackend): + attn_metadata = TreeAttentionMetadata.from_eagle_attn_metadata( + attn_metadata) elif self.method == "deepseek_mtp": query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] max_query_len = query_lens.max().item() @@ -160,11 +207,7 @@ def propose( per_layer_attn_metadata = {} for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) - else: - num_input_tokens = num_tokens + num_input_tokens = self.get_padded_num_input_tokens(num_tokens) # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states @@ -183,10 +226,10 @@ def propose( last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) - draft_token_ids = logits.argmax(dim=-1) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: + draft_token_ids = logits.argmax(dim=-1) # [batch_size, 1] return draft_token_ids.view(-1, 1) @@ -194,41 +237,98 @@ def propose( # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. - # Generate the remaining draft tokens. + # Sample a draft token for each tree child at the root level. + draft_token_ids = torch.topk(logits, + self.child_drafts_per_level[0], + dim=-1).indices + + # Prepare to generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] + draft_last_hidden_states = hidden_states[last_token_indices].view( + batch_size, 1, -1) - positions = target_positions[last_token_indices] - hidden_states = hidden_states[last_token_indices] - if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: - input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) - else: - input_batch_size = batch_size + # Setup attention metadata for drafting. attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] - for _ in range(self.num_speculative_tokens - 1): - # Update the inputs. - # cast to int32 is crucial when eagle model is compiled. - # tensor.argmax() returns int64 by default. - input_ids = draft_token_ids_list[-1].int() - positions += 1 - - # NOTE(woosuk): We should handle the case where the draft model - # generates tokens beyond the max model length. Since it is complex - # to remove such requests from the batch, we keep them in the batch - # but adjust the position ids and slot mappings to avoid the - # out-of-range access during the model execution. The draft tokens - # generated with this adjustment should be ignored. - exceeds_max_model_len = positions >= self.max_model_len + + # Initialize empty tensors for concatenation with the level outputs. 
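# Worked example (tree chosen for illustration) of the per-level properties
# precomputed in __init__ above, for the uniform token tree
# [(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]:
#   num_drafts_per_level   == [2, 4]   # drafts at depth 1 and depth 2
#   cu_drafts_per_level    == [2, 6]   # cumulative drafts up to each level
#   child_drafts_per_level == [2, 2]   # top-k children sampled per parent
#   first_branching_level  == 0        # 2 drafts at level 0 exceeds 0 + 1
# A pure chain such as [(0,), (0, 0), (0, 0, 0)] yields cu_drafts_per_level of
# [1, 2, 3] and first_branching_level of None, so every level can stay on the
# cheaper standard (non-tree) decode path in the drafting loop below.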
+ input_ids = torch.empty(0, + device=self.input_ids.device, + dtype=self.input_ids.dtype) + positions = torch.empty(0, + device=self.positions.device, + dtype=self.positions.dtype) + hidden_states = torch.empty(0, + device=self.hidden_states.device, + dtype=self.hidden_states.dtype) + # Precompute the draft token positions. + base_positions = target_positions[last_token_indices] + flattened_draft_positions = ( + base_positions.view(batch_size, -1) + + self.tree_draft_pos_offsets[:batch_size, :]) + total_num_drafts = 0 + tree_depth = len(self.cu_drafts_per_level) + for level in range(tree_depth - 1): + num_level_drafts = self.cu_drafts_per_level[ + level] - total_num_drafts + total_num_drafts = self.cu_drafts_per_level[level] + + # Get draft positions for RoPE. + draft_positions = base_positions + (level + 1) + exceeds_max_model_len = (base_positions + + total_num_drafts) >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) + clamped_draft_positions = torch.where( + exceeds_max_model_len, + 0, + draft_positions, + ) + draft_positions = clamped_draft_positions.repeat_interleave( + num_level_drafts).reshape(batch_size, -1) + + # Broadcast draft hidden states for each child. + draft_hidden_states = draft_last_hidden_states.repeat_interleave( + self.child_drafts_per_level[level], dim=1) + + if (self.first_branching_level is not None + and level >= self.first_branching_level): + # Draft branching has occurred. Tree attention must be used to + # predict subsequent draft tokens. + query_len = total_num_drafts - self.first_branching_level + input_ids = torch.cat([input_ids, draft_token_ids], dim=1) + positions = torch.cat([positions, draft_positions], dim=1) + hidden_states = torch.cat([hidden_states, draft_hidden_states], + dim=1) + + # Build new attention metadata for the next level of drafts. + # This is necessary to support tree attention. + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_len * self.arange[:batch_size + 1], + seq_lens=attn_metadata.seq_lens + num_level_drafts, + num_reqs=batch_size, + num_actual_tokens=batch_size * query_len, + max_query_len=query_len, + ) + attn_metadata = self.runner.attn_metadata_builders[ + 0].build_for_drafting( + common_attn_metadata=common_attn_metadata, + tree_level_offset=self.first_branching_level, + ) + + # Apply new attention metadata to all layers. + per_layer_attn_metadata = {} + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + else: + # Still a chain of drafts. Continue performing standard attention, + # which is more efficient than tree attention. + query_len = 1 + input_ids = draft_token_ids + positions = draft_positions + hidden_states = draft_hidden_states - # Increment the sequence lengths. - attn_metadata.max_seq_len += 1 - attn_metadata.seq_lens += 1 # Consider max model length. attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, self.max_model_len) @@ -237,43 +337,53 @@ def propose( attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) # Compute the slot mapping. 
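# Illustrative sketch (assumed helper, not from this patch): mapping draft
# token positions to paged-KV slots through the per-request block table, as
# done just below. Requests already at the max model length are pointed at a
# padding slot so their (discarded) drafts cannot clobber the cache.
import torch

PADDING_SLOT_ID = -1  # placeholder value for illustration


def positions_to_slot_mapping(
    block_table: torch.Tensor,            # [batch, max_blocks_per_req]
    positions: torch.Tensor,              # [batch, query_len]
    exceeds_max_model_len: torch.Tensor,  # [batch] bool
    block_size: int,
) -> torch.Tensor:
    block_numbers = positions // block_size
    block_ids = block_table.gather(dim=1, index=block_numbers)
    slots = block_ids * block_size + positions % block_size
    slots[exceeds_max_model_len] = PADDING_SLOT_ID  # masks whole rows
    return slots.view(-1)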
- block_numbers = clamped_positions // self.block_size - block_ids = block_table.gather(dim=1, - index=block_numbers.view(-1, 1)) - block_ids = block_ids.view(-1) - attn_metadata.slot_mapping = (block_ids * self.block_size + - clamped_positions % self.block_size) + query_positions = flattened_draft_positions[:, level:level + + query_len] + block_numbers = query_positions // self.block_size + block_ids = block_table.gather(dim=1, index=block_numbers) + slot_mapping = (block_ids * self.block_size + + query_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. - attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, - PADDING_SLOT_ID) + slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID + attn_metadata.slot_mapping = slot_mapping.view(-1) - # copy inputs to buffer for cudagraph - self.input_ids[:batch_size] = input_ids - self.positions[:batch_size] = clamped_positions - self.hidden_states[:batch_size] = hidden_states + # Copy inputs to buffer for cudagraph. + num_tokens = attn_metadata.num_actual_tokens + self.input_ids[:num_tokens] = input_ids.view(-1) + self.positions[:num_tokens] = positions.view(-1) + self.hidden_states[:num_tokens] = hidden_states.view( + num_tokens, -1) + num_input_tokens = self.get_padded_num_input_tokens(num_tokens) # Run the model. with set_forward_context(per_layer_attn_metadata, self.vllm_config, - num_tokens=input_batch_size): - last_hidden_states, hidden_states = self.model( - self.input_ids[:input_batch_size], - self.positions[:input_batch_size], - self.hidden_states[:input_batch_size], + num_tokens=num_input_tokens): + last_hidden_states, _ = self.model( + self.input_ids[:num_input_tokens], + self.positions[:num_input_tokens], + self.hidden_states[:num_input_tokens], ) - hidden_states = hidden_states[:batch_size] - logits = self.model.compute_logits(last_hidden_states[:batch_size], - None) - # TODO(wenlong): get more than one token for tree attention - draft_token_ids = logits.argmax(dim=-1) + # Get the last hidden states for predicting the drafts. + draft_last_hidden_states = last_hidden_states[:num_tokens].view( + batch_size, query_len, -1)[:, -num_level_drafts:] + + # Sample a draft token for each child at the current tree level. 
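# Illustrative sketch (shapes assumed): expanding the next tree level with
# greedy top-k, mirroring the sampling below. Each of the num_level_drafts
# nodes at this level proposes child_k children, so the next level holds
# num_level_drafts * child_k drafts per request.
import torch


def expand_tree_level(
    level_logits: torch.Tensor,  # [batch * num_level_drafts, vocab_size]
    batch_size: int,
    child_k: int,
) -> torch.Tensor:
    # [batch * num_level_drafts, child_k] -> [batch, num_level_drafts * child_k]
    return torch.topk(level_logits, child_k, dim=-1).indices.view(batch_size, -1)


# e.g. batch_size=2, num_level_drafts=2, child_k=2, vocab_size=8:
# expand_tree_level(torch.randn(4, 8), 2, 2).shape == torch.Size([2, 4])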
+ logits = self.model.compute_logits( + draft_last_hidden_states.reshape(batch_size * num_level_drafts, + -1), None) + draft_token_ids = torch.topk(logits, + self.child_drafts_per_level[level + + 1], + dim=-1).indices.view(batch_size, -1) draft_token_ids_list.append(draft_token_ids) # [batch_size, num_speculative_tokens] - draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - return draft_token_ids + output = torch.cat(draft_token_ids_list, dim=1) + return torch.cat(draft_token_ids_list, dim=1) @staticmethod def prepare_inputs( @@ -292,8 +402,7 @@ def prepare_inputs( # a + b, a + b + 1, ..., a + b + c - n3 - 1] # [0, a, a + b, a + b + c] -> [a, b, c] - query_len_per_req = (cu_target_query_lens[1:] - - cu_target_query_lens[:-1]) + query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1] # [a, b, c] -> [a - n1, b - n2, c - n3] num_tokens_per_req = query_len_per_req - num_rejected_tokens @@ -317,12 +426,12 @@ def prepare_inputs( return cu_num_tokens, token_indices def load_model(self, target_model: nn.Module) -> None: - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config + draft_model_config = self.vllm_config.speculative_config.draft_model_config target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, Attention).keys()) from vllm.compilation.backends import set_model_tag + with set_model_tag("eagle_head"): self.model = get_model(vllm_config=self.vllm_config, model_config=draft_model_config) @@ -335,33 +444,29 @@ def load_model(self, target_model: nn.Module) -> None: if supports_multimodal(target_model): # handle multimodality - self.model.config.image_token_index = ( - target_model.config.image_token_index) + self.model.config.image_token_index = target_model.config.image_token_index target_language_model = target_model.get_language_model() else: target_language_model = target_model # share embed_tokens with the target model if needed - if get_pp_group().world_size == 1 \ - and self.model.model.embed_tokens.weight.shape \ - == target_language_model.model.embed_tokens.weight.shape: + if (get_pp_group().world_size == 1 + and self.model.model.embed_tokens.weight.shape + == target_language_model.model.embed_tokens.weight.shape): logger.info( - "Assuming the EAGLE head shares the same vocab embedding" \ - " with the target model." - ) + "Assuming the EAGLE head shares the same vocab embedding" + " with the target model.") del self.model.model.embed_tokens - self.model.model.embed_tokens = ( - target_language_model.model.embed_tokens) + self.model.model.embed_tokens = target_language_model.model.embed_tokens else: logger.info( - "The EAGLE head's vocab embedding will be loaded separately" \ - " from the target model." 
-            )
+                "The EAGLE head's vocab embedding will be loaded separately"
+                " from the target model.")
 
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
-        if self.vllm_config.speculative_config.method != "eagle3" and \
-            hasattr(target_language_model, "lm_head"):
+        if self.vllm_config.speculative_config.method != "eagle3" and hasattr(
+                target_language_model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
             self.model.lm_head = target_language_model.lm_head
 
@@ -390,12 +495,12 @@ def validate_same_kv_cache_group(self,
         for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
             for layer_name in kv_cache_group.layer_names:
                 kv_cache_groups[layer_name] = id
-        assert len(
+        assert (len(
             set([
                 kv_cache_groups[layer_name]
                 for layer_name in self.attn_layer_names
-            ])
-        ) == 1, "All eagle layers should belong to the same kv cache group"
+            ])) == 1
+        ), "All eagle layers should belong to the same kv cache group"
 
 
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
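# Illustrative sketch (assumed names, not from this patch): the invariant that
# validate_same_kv_cache_group asserts above, stated as a standalone check.
# All EAGLE attention layers must map to one KV-cache group because the
# proposer builds a single attention-metadata object and reuses it for every
# draft layer.
def eagle_layers_share_kv_cache_group(
    layer_to_group: dict[str, int],
    eagle_layer_names: list[str],
) -> bool:
    return len({layer_to_group[name] for name in eagle_layer_names}) == 1


# e.g. eagle_layers_share_kv_cache_group({"eagle.0": 0, "eagle.1": 0},
#                                        ["eagle.0", "eagle.1"]) is True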