
[Perf][Spec Decode] EAGLE Kernel Fusion + Synchronization Overhead Reduction #20078

Open
wants to merge 6 commits into main
208 changes: 127 additions & 81 deletions in vllm/v1/spec_decode/eagle.py
@@ -12,11 +12,13 @@
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.triton_utils import triton
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
FlashAttentionMetadata)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
from vllm.v1.spec_decode.utils import (advance_state_kernel,
prepare_eagle_input_kernel)

logger = init_logger(__name__)

@@ -75,6 +77,15 @@ def __init__(
device=device,
dtype=torch.int32)

# Used to store precomputed values from load_inputs()
# so they can be used in propose()
self.last_token_indices = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=device)
self.seq_lens = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=device)

def propose(
self,
# [num_tokens]
@@ -92,40 +103,21 @@ def propose(
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
num_tokens: int,
max_num_tokens: int,
max_seq_len: int,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1

if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size

# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids

# FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int()

if self.method in ["eagle", "eagle3"]:
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] -
cu_num_tokens[:-1]).max().item()
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_tokens,
max_query_len=max_num_tokens,
query_start_loc=cu_num_tokens,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
seq_lens=self.seq_lens,
block_table=block_table,
slot_mapping=target_slot_mapping,
slot_mapping=target_slot_mapping[:num_tokens],
# TODO(woosuk): Support cascade attention.
use_cascade=False,
common_prefix_len=0,
@@ -134,15 +126,12 @@ def propose(
suffix_kv_lens=None,
)
elif self.method == "deepseek_mtp":
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()

common_attn_metadata = CommonAttentionMetadata(
query_start_loc=cu_num_tokens,
seq_lens=seq_lens,
seq_lens=self.seq_lens,
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
max_query_len=self.max_num_tokens,
)

assert self.runner is not None
@@ -165,9 +154,6 @@ def propose(
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states

with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
@@ -181,7 +167,7 @@ def propose(
last_hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
sample_hidden_states = last_hidden_states[self.last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1)

@@ -197,8 +183,8 @@ def propose(
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]

positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
positions = target_positions[self.last_token_indices]
hidden_states = hidden_states[self.last_token_indices]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
@@ -208,52 +194,12 @@ def propose(
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
for _ in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_list[-1].int()
positions += 1

# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)

# Increment the sequence lengths.
attn_metadata.max_seq_len += 1
attn_metadata.seq_lens += 1
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
PADDING_SLOT_ID)

# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
self.advance_speculative_state(draft_token_ids_list[-1], positions,
hidden_states, attn_metadata,
batch_size)

# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
@@ -275,6 +221,58 @@ def propose(
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids

def advance_speculative_state(self, draft_token_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
batch_size: int):
"""
Advances the speculative decoding state and metadata by one step.

Parameters:
----------
draft_token_ids (torch.Tensor): Token IDs generated by the draft model
positions (torch.Tensor): Position indices for the draft tokens
hidden_states (torch.Tensor): Corresponding hidden states for the tokens
attn_metadata (FlashAttentionMetadata): FlashAttention metadata
batch_size (int): Number of sequences to update.
"""

# Calculate number of thread blocks
grid = lambda meta: (triton.cdiv(batch_size, meta['BLOCK_SIZE']), )
attn_metadata.slot_mapping = torch.empty_like(positions)
advance_state_kernel[grid](
# === Input tensors ===
draft_token_ids,
positions,

# === Model input buffers to be updated ===
self.input_ids[:batch_size],
self.positions[:batch_size],

# === Metadata tensors ===
attn_metadata.seq_lens,
attn_metadata.block_table,
attn_metadata.slot_mapping,

# === Scalar configuration ===
self.max_model_len,
self.block_size,
self.max_model_len // self.block_size,

# === Execution control ===
batch_size,
BLOCK_SIZE=1024,
PADDING_SLOT_ID=PADDING_SLOT_ID)

self.hidden_states[:batch_size] = hidden_states

# Increment the sequence lengths.
attn_metadata.max_seq_len += 1
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
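
For reference, below is a minimal eager-mode sketch of the per-step update that advance_state_kernel fuses into a single launch. It is adapted from the loop body removed from propose() above and is included only to document what the kernel computes; it is not part of this change.

# Eager-mode equivalent of advance_state_kernel (documentation sketch only).
input_ids = draft_token_ids.int()  # argmax() returns int64; int32 is required here
positions += 1
exceeds_max_model_len = positions >= self.max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
attn_metadata.seq_lens += 1
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
block_numbers = clamped_positions // self.block_size
block_ids = attn_metadata.block_table.gather(
    dim=1, index=block_numbers.view(-1, 1)).view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
                              clamped_positions % self.block_size)
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions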

@staticmethod
def prepare_inputs(
# [batch_size + 1]
@@ -301,7 +299,7 @@ def prepare_inputs(
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
token_indices = torch.empty(
token_indices = torch.zeros(
num_tokens,
dtype=torch.int32,
device=cu_target_query_lens.device,
@@ -316,6 +314,54 @@
)
return cu_num_tokens, token_indices
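
A small worked example (hypothetical values) of what prepare_inputs() returns:

# cu_target_query_lens: [0, 3, 5, 9]       -> per-request query lens a=3, b=2, c=4
# num_rejected_tokens:  [1, 0, 2]
# num_tokens_per_req:   [2, 2, 2]           (query_len - num_rejected per request)
# cu_num_tokens:        [0, 2, 4, 6]
# token_indices:        [0, 1, 3, 4, 5, 6]  (the first num_tokens_per_req[i] positions
#                                            of each request in the flattened target batch)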

def load_inputs(self, target_token_ids: torch.Tensor,
target_positions: torch.Tensor,
target_hidden_states: torch.Tensor,
next_token_ids_gpu: torch.Tensor,
cu_num_tokens: torch.Tensor, num_scheduled_tokens: int):
"""
Loads token IDs, positions, and hidden states into the EAGLE model's input buffers

This logic was moved here from EagleProposer.propose()

Parameters:
----------
target_token_ids (torch.Tensor): Draft-step token IDs
target_positions (torch.Tensor): Token position indices
target_hidden_states (torch.Tensor): Token hidden states
next_token_ids_gpu (torch.Tensor): Sampled final token IDs
cu_num_tokens (torch.Tensor): Cumulative tokens from prepare_inputs()
num_scheduled_tokens (int): Total number of tokens scheduled
"""

self.last_token_indices = cu_num_tokens[1:] - 1

# FA requires seq_len to have dtype int32.
self.seq_lens = (target_positions[self.last_token_indices] + 1).int()

if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size

# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_scheduled_tokens -
1] = target_token_ids[:num_scheduled_tokens][1:]

# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[self.last_token_indices] = next_token_ids_gpu

# copy inputs to buffer for cudagraph
self.positions[:
num_scheduled_tokens] = target_positions[:
num_scheduled_tokens]
self.hidden_states[:
num_scheduled_tokens] = target_hidden_states[:
num_scheduled_tokens]
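
Below is a tiny self-contained demo (toy values, not part of this diff) of the shift-and-replace performed on the input ids above:

import torch

# Two requests with 2 and 3 scheduled tokens: [a1, a2 | b1, b2, b3]
target_token_ids = torch.tensor([10, 11, 20, 21, 22], dtype=torch.int32)
next_token_ids = torch.tensor([12, 23], dtype=torch.int32)  # tokens sampled from the target model
cu_num_tokens = torch.tensor([0, 2, 5])
last_token_indices = cu_num_tokens[1:] - 1  # -> [1, 4]

input_ids = torch.empty_like(target_token_ids)
input_ids[:-1] = target_token_ids[1:]            # shift left by one token
input_ids[last_token_indices] = next_token_ids   # each request's last slot gets the sampled token
print(input_ids)  # tensor([11, 12, 21, 22, 23], dtype=torch.int32)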

def load_model(self, target_model: nn.Module) -> None:
draft_model_config = \
self.vllm_config.speculative_config.draft_model_config
79 changes: 79 additions & 0 deletions in vllm/v1/spec_decode/utils.py
@@ -44,3 +44,82 @@ def prepare_eagle_input_kernel(
index_start + offset,
mask=offset < num_tokens,
)


@triton.jit
def advance_state_kernel(
draft_token_ids_ptr,
positions_ptr,

# === Model input buffers to be updated ===
model_input_ids_ptr,
model_positions_ptr,

# === Metadata tensors ===
seq_lens_ptr,
block_table_ptr,
slot_mapping_ptr,

# === Scalar configuration ===
model_max_len: int,
model_block_size: int,
model_block_stride: int,

# === Execution control ===
n_elements: int,
BLOCK_SIZE: tl.constexpr,
PADDING_SLOT_ID: tl.constexpr,
):
# Triton kernel to perform draft model state advancement.

pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
draft_token_list_last = tl.load(draft_token_ids_ptr + offsets, mask=mask)
position = tl.load(positions_ptr + offsets, mask=mask)
seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask)

# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_id = draft_token_list_last.cast(tl.int32)
position = position + 1

# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = position >= model_max_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_position = tl.where(exceeds_max_model_len, 0, position)

# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
seq_lens += 1
seq_lens = tl.where(exceeds_max_model_len, 1, seq_lens)

block_numbers = clamped_position // model_block_size
block_offsets = clamped_position % model_block_size

block_ids = tl.load(block_table_ptr + model_block_stride * offsets +
block_numbers,
mask=mask)

# Compute slot mapping
slot_mapping = block_ids * model_block_size + block_offsets

# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
slot_mapping = tl.where(exceeds_max_model_len, PADDING_SLOT_ID,
slot_mapping)

tl.store(model_input_ids_ptr + offsets, input_id, mask=mask)
tl.store(positions_ptr + offsets, position, mask=mask)
tl.store(model_positions_ptr + offsets, clamped_position, mask=mask)
tl.store(seq_lens_ptr + offsets, seq_lens, mask=mask)
tl.store(slot_mapping_ptr + offsets, slot_mapping, mask=mask)
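
For intuition, a tiny worked example (hypothetical values) of the slot-mapping arithmetic above:

# model_block_size = 16, clamped_position = 37
# block_numbers = 37 // 16 = 2              -> third block of this request's block table
# block_offsets = 37 %  16 = 5
# block_ids     = block_table[req, 2] = 91  (hypothetical physical block id)
# slot_mapping  = 91 * 16 + 5 = 1461        -> KV-cache slot written for this draft token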