
[V1][Spec Decode] Non greedy sample with EAGLE / Reduce memory allocation for Rejection Sampler #16077

Closed
130 changes: 99 additions & 31 deletions vllm/v1/spec_decode/eagle.py
@@ -24,6 +24,34 @@ def __init__(
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs,
device=device)

max_batch_size = vllm_config.scheduler_config.max_num_seqs
vocab_size = vllm_config.model_config.get_vocab_size()

# Set up buffers for draft token ids and probs to be
# reused across steps for Rejection Sampling
self.curr_batch_size = -1
self._draft_token_ids_buffer = torch.zeros(max_batch_size,
self.num_speculative_tokens,
dtype=torch.int32,
device=device)
self._draft_token_ids_buffer_shape = self._draft_token_ids_buffer.shape
self._draft_probs_buffer = torch.zeros(max_batch_size,
self.num_speculative_tokens,
vocab_size,
dtype=torch.float32,
device=device)
self._draft_probs_buffer_shape = self._draft_probs_buffer.shape
Collaborator

Just want to point this out; we might need to fix it. The draft_probs_buffer has the following size (plugging in the numbers for llama3-8B):
256 * 10 * 128256 * 4 bytes / 1024^3 ≈ 1.2 GiB
There is a small chance that this could trigger an OOM if it happens after vLLM preallocates all of the memory for the KV cache, but it should not be a big problem.
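For reference, a quick back-of-the-envelope check of that estimate (an illustrative sketch only; max_num_seqs = 256, num_speculative_tokens = 10, and the llama3-8B vocab size of 128256 are assumed from the numbers above):

# Illustrative size estimate only, not code from this PR.
max_num_seqs = 256            # scheduler_config.max_num_seqs (assumed)
num_speculative_tokens = 10   # assumed
vocab_size = 128256           # llama3-8B vocabulary (assumed)
bytes_per_element = 4         # float32

probs_bytes = max_num_seqs * num_speculative_tokens * vocab_size * bytes_per_element
print(probs_bytes / 1024**3)  # ~1.22 GiB

The draft token id buffer is comparatively tiny (256 * 10 * 4 bytes ≈ 10 KiB), so the probs buffer dominates the footprint.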

Contributor Author

Is there a way to allocate this before vLLM preallocates memory for the KV cache?

Contributor Author

The preallocation of the rejection-sampler buffer happens here, which is when GPUModelRunner is created. Could you point me to the line of code that computes the available GPU memory and allocates the KV cache based on it?


def get_draft_token_ids(self) -> torch.Tensor:
Collaborator

Since we have the get_draft_token_ids API, I'm wondering if it might be cleaner to move all proposing logic (https://github.com/vllm-project/vllm/blob/660a6b0ed756bb7ca0459786fd8302b9ede2c280/vllm/v1/worker/gpu_model_runner.py#L1171C8-L1229C14) under this function?

Contributor Author (@ekagra-ranjan, Apr 10, 2025)

The get_draft_token_ids API is more like a getter for the right slice of the preallocated buffer, so repeated calls just return a handle to the buffer. If we moved the proposer logic here, then repeated calls would propose again. We could refactor the code and put that section under a new API if that makes sense.
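For illustration only (a minimal standalone sketch, not code from this PR): basic slicing in PyTorch returns a view that shares storage with the preallocated buffer, so repeated getter calls hand back the same memory without copying or re-running the proposer.

import torch

# Hypothetical stand-ins for the preallocated buffer and current batch size.
buffer = torch.zeros(4, 3, dtype=torch.int32)
curr_batch_size = 2

def get_draft_token_ids() -> torch.Tensor:
    # Returns a view into the buffer, not a copy.
    return buffer[:curr_batch_size, :]

view = get_draft_token_ids()
view[0, 0] = 7                                                # writes through to the buffer
assert buffer[0, 0].item() == 7
assert get_draft_token_ids().data_ptr() == buffer.data_ptr()  # same underlying storage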

# [batch_size, num_speculative_tokens]
assert self.curr_batch_size != -1, "EagleProposer hasn't proposed yet."
return self._draft_token_ids_buffer[:self.curr_batch_size, :]

def get_draft_probs(self) -> torch.Tensor:
# [batch_size, num_speculative_tokens, vocab_size]
assert self.curr_batch_size != -1, "EagleProposer hasn't proposed yet."
return self._draft_probs_buffer[:self.curr_batch_size, :, :]

def propose(
self,
# [num_tokens]
@@ -41,9 +69,32 @@ def propose(
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
):
# make sure that the buffer sizes have not been changed
# by any other operation
assert self._draft_probs_buffer_shape.numel(
) == self._draft_probs_buffer.numel(), (
"Size of self._draft_probs_buffer has been changed. "
"Make sure it remains the same.")

assert self._draft_token_ids_buffer_shape.numel(
) == self._draft_token_ids_buffer.numel(), (
"Size of self._draft_token_ids_buffer has been changed. "
"Make sure it remains the same.")

# restore the buffer shapes if they have been
# changed by any other operation
if (self._draft_probs_buffer.shape != self._draft_probs_buffer_shape):
Collaborator

Why might this buffer be reshaped by any operation?

Contributor Author

Since we are passing the buffer outside of the function, the caller gets a handle to this buffer and might accidentally reshape it. I am assuming someone might do that in the future, since it is not obvious that they shouldn't. The check will help in those cases. Let me know if this check should be removed.
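As a side note, a minimal sketch (not part of the PR) of the two behaviours the guard deals with: an in-place metadata op on a shared handle does change the buffer's shape, while reshape returns a new tensor and leaves the original untouched, so restoring the shape requires a reassignment.

import torch

buf = torch.zeros(4, 6)
expected_shape = buf.shape

# A caller holding the handle could change the shape in place:
buf.unsqueeze_(0)                              # shape is now (1, 4, 6)
assert buf.shape != expected_shape
assert buf.numel() == expected_shape.numel()   # element count is unchanged

# reshape alone does not restore it, because it returns a new tensor:
buf.reshape(expected_shape)
assert buf.shape != expected_shape

# Restoring requires rebinding the name; the storage stays the same:
restored = buf.reshape(expected_shape)
assert restored.shape == expected_shape
assert restored.data_ptr() == buf.data_ptr()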

self._draft_probs_buffer = self._draft_probs_buffer.reshape(
self._draft_probs_buffer_shape)

if (self._draft_token_ids_buffer.shape
!= self._draft_token_ids_buffer_shape):
self._draft_token_ids_buffer = self._draft_token_ids_buffer.reshape(
self._draft_token_ids_buffer_shape)

num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
self.curr_batch_size = batch_size
last_token_indices = cu_num_tokens[1:] - 1

input_ids = torch.empty_like(target_token_ids)
@@ -82,26 +133,25 @@ def propose(
)
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
logits, sampling_metadata)
compute_probs_and_sample_next_token(logits, sampling_metadata, 0,
batch_size,
self._draft_token_ids_buffer,
self._draft_probs_buffer)

# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1] and [batch_size, 1, vocab_size]
return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)
return

# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
draft_probs_list = [draft_probs]

positions = target_positions[last_token_indices]
hidden_states = sample_hidden_states
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size]
for _ in range(self.num_speculative_tokens - 1):
for speculative_token_idx in range(self.num_speculative_tokens - 1):
# Update the inputs.
input_ids = draft_token_ids_list[-1]
input_ids = self._draft_token_ids_buffer[:batch_size,
speculative_token_idx]
positions += 1
attn_metadata.max_seq_len += 1
attn_metadata.seq_lens += 1
@@ -121,16 +171,11 @@ def propose(
positions=positions,
)
logits = self.model.compute_logits(hidden_states, None)
draft_token_ids, probs = compute_probs_and_sample_next_token(
logits, sampling_metadata)
draft_token_ids_list.append(draft_token_ids)
draft_probs_list.append(probs)

# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
# [batch_size, num_speculative_tokens, vocab_size]
draft_probs = torch.stack(draft_probs_list, dim=1)
return draft_token_ids, draft_probs
compute_probs_and_sample_next_token(logits, sampling_metadata,
speculative_token_idx + 1,
batch_size,
self._draft_token_ids_buffer,
self._draft_probs_buffer)

@staticmethod
def prepare_inputs(
@@ -203,36 +248,59 @@ def forward(
def compute_probs_and_sample_next_token(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
speculative_token_idx: int,
batch_size: int,
# [batch_size, num_speculative_tokens]
draft_token_ids_buffer: torch.Tensor,
# [batch_size, num_speculative_tokens, vocab_size]
draft_probs_buffer: torch.Tensor,
):
# We pass in the entire preallocated buffers draft_token_ids_buffer
# and draft_probs_buffer and select the portion of the buffer that
# we need to fill in using batch_size and speculative_token_idx.
# This allows us to write in-place. If we passed in the specific
# tensor slices directly to the function, i.e.,
# draft_token_ids_buffer[:batch_size, speculative_token_idx]
# as draft_token_ids, then draft_token_ids = logits.argmax(dim=-1)
# would create a new tensor and not allow in-place writes.

if sampling_metadata.all_greedy:
# For greedy requests, draft_probs is not used in rejection sampling.
# Therefore, we can just store the raw logits.
probs = logits
next_token_ids = logits.argmax(dim=-1)
return next_token_ids, probs
draft_probs_buffer[:batch_size, speculative_token_idx, :] = logits
draft_token_ids_buffer[:batch_size,
speculative_token_idx] = logits.argmax(dim=-1)

is_greedy = sampling_metadata.temperature == -1
temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
logits.div_(temperature.view(-1, 1))
probs = logits.softmax(dim=-1, dtype=torch.float32)
draft_probs_buffer[:batch_size, speculative_token_idx, :] = logits.softmax(
dim=-1, dtype=torch.float32)

# NOTE(woosuk): Currently, we ignore most of the sampling parameters in
# generating the draft tokens. We only use the temperature. While this
# could degrade the acceptance rate, it does not affect the distribution
# of the generated tokens after rejection sampling.

# TODO(woosuk): Consider seeds.
q = torch.empty_like(probs)
q = torch.empty_like(draft_probs_buffer[:batch_size,
speculative_token_idx, :])
q.exponential_()
next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
draft_token_ids_buffer[:batch_size, speculative_token_idx] = \
draft_probs_buffer[:batch_size, speculative_token_idx, :] \
.div_(q) \
.argmax(dim=-1) \
.view(-1)
if not sampling_metadata.all_random:
greedy_token_ids = probs.argmax(dim=-1)
next_token_ids = torch.where(
greedy_token_ids = draft_probs_buffer[:batch_size,
speculative_token_idx, :].argmax(
dim=-1)
draft_token_ids_buffer[:batch_size, speculative_token_idx] = \
torch.where(
is_greedy,
greedy_token_ids,
next_token_ids,
draft_token_ids_buffer[:batch_size, speculative_token_idx],
)
return next_token_ids, probs


@triton.jit
14 changes: 11 additions & 3 deletions vllm/v1/spec_decode/ngram_proposer.py
@@ -10,6 +10,7 @@
class NgramProposer:

def __init__(self, vllm_config: VllmConfig):
self._draft_token_ids = None
# Minimum length of the n-gram to match.
self.min_n = vllm_config.speculative_config.prompt_lookup_min
# Maximum length of the n-gram to match.
@@ -22,10 +23,16 @@ def __init__(self, vllm_config: VllmConfig):
# This usually takes less than 1 second.
self.propose(np.zeros(1024, dtype=np.int32))

def get_draft_token_ids(self):
return self._draft_token_ids

def get_draft_probs(self):
return None

def propose(
self,
context_token_ids: np.ndarray,
) -> Optional[np.ndarray]:
):
"""Proposes the next sequence of tokens based on n-gram pattern
matching in the context. The function finds matches of the last n
tokens in the previous context, and returns k tokens that followed
@@ -54,8 +61,9 @@ def propose(
for n in range(self.max_n, self.min_n - 1, -1):
result = _find_subarray_kmp(context_token_ids, n, self.k)
if result is not None:
return result
return None
self._draft_token_ids = result
return
return

def load_model(self, *args, **kwargs):
# No model to load.
13 changes: 5 additions & 8 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1108,7 +1108,7 @@ def execute_model(
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
self.drafter.get_draft_probs(),
target_logits,
bonus_token_ids,
sampling_metadata,
@@ -1217,7 +1217,7 @@ def execute_model(
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]

draft_token_ids, draft_probs = self.drafter.propose(
self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
@@ -1227,10 +1227,7 @@ def execute_model(
block_table=attn_metadata.block_table,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
# TODO(woosuk): Cache draft_probs and use it for rejection sampling
# in the next step.
del draft_probs
spec_token_ids = self.drafter.get_draft_token_ids().tolist()

return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
@@ -1265,8 +1262,8 @@ def generate_draft_token_ids(
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx])
self.drafter.propose(self.input_batch.token_ids_cpu[i, :end_idx])
drafter_output = self.drafter.get_draft_token_ids()
if drafter_output is None or len(drafter_output) == 0:
draft_token_ids.append([])
else: