diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 156f5764e8d..79c470d1216 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -12,11 +12,13 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.triton_utils import triton from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata, FlashAttentionMetadata) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel +from vllm.v1.spec_decode.utils import (advance_state_kernel, + prepare_eagle_input_kernel) logger = init_logger(__name__) @@ -75,6 +77,15 @@ def __init__( device=device, dtype=torch.int32) + # Used to store precomputed values from load_model() + # so they can be used in propose() + self.last_token_indices = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=device) + self.seq_lens = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=device) + def propose( self, # [num_tokens] @@ -92,40 +103,21 @@ def propose( # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, + num_tokens: int, + max_num_tokens: int, + max_seq_len: int, ) -> torch.Tensor: - num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = cu_num_tokens[1:] - 1 - - if self.method == "eagle3": - assert isinstance(self.model, Eagle3LlamaForCausalLM) - target_hidden_states = self.model.combine_hidden_states( - target_hidden_states) - assert target_hidden_states.shape[-1] == self.hidden_size - - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids - - # FA requires seq_len to have dtype int32. - seq_lens = (target_positions[last_token_indices] + 1).int() if self.method in ["eagle", "eagle3"]: - # FIXME(woosuk): The below two ops cause synchronization. Optimize. - max_seq_len = seq_lens.max().item() - max_num_tokens = (cu_num_tokens[1:] - - cu_num_tokens[:-1]).max().item() attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_tokens, max_query_len=max_num_tokens, query_start_loc=cu_num_tokens, max_seq_len=max_seq_len, - seq_lens=seq_lens, + seq_lens=self.seq_lens, block_table=block_table, - slot_mapping=target_slot_mapping, + slot_mapping=target_slot_mapping[:num_tokens], # TODO(woosuk): Support cascade attention. use_cascade=False, common_prefix_len=0, @@ -134,15 +126,12 @@ def propose( suffix_kv_lens=None, ) elif self.method == "deepseek_mtp": - query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] - max_query_len = query_lens.max().item() - common_attn_metadata = CommonAttentionMetadata( query_start_loc=cu_num_tokens, - seq_lens=seq_lens, + seq_lens=self.seq_lens, num_reqs=batch_size, num_actual_tokens=num_tokens, - max_query_len=max_query_len, + max_query_len=self.max_num_tokens, ) assert self.runner is not None @@ -165,9 +154,6 @@ def propose( num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens - # copy inputs to buffer for cudagraph - self.positions[:num_tokens] = target_positions - self.hidden_states[:num_tokens] = target_hidden_states with set_forward_context(per_layer_attn_metadata, self.vllm_config, @@ -181,7 +167,7 @@ def propose( last_hidden_states = ret_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states - sample_hidden_states = last_hidden_states[last_token_indices] + sample_hidden_states = last_hidden_states[self.last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) draft_token_ids = logits.argmax(dim=-1) @@ -197,8 +183,8 @@ def propose( # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] - positions = target_positions[last_token_indices] - hidden_states = hidden_states[last_token_indices] + positions = target_positions[self.last_token_indices] + hidden_states = hidden_states[self.last_token_indices] if self.use_cuda_graph and \ batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) @@ -208,52 +194,12 @@ def propose( attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] for _ in range(self.num_speculative_tokens - 1): - # Update the inputs. - # cast to int32 is crucial when eagle model is compiled. - # tensor.argmax() returns int64 by default. - input_ids = draft_token_ids_list[-1].int() - positions += 1 - - # NOTE(woosuk): We should handle the case where the draft model - # generates tokens beyond the max model length. Since it is complex - # to remove such requests from the batch, we keep them in the batch - # but adjust the position ids and slot mappings to avoid the - # out-of-range access during the model execution. The draft tokens - # generated with this adjustment should be ignored. - exceeds_max_model_len = positions >= self.max_model_len - # Mask out the position ids that exceed the max model length. - # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) - - # Increment the sequence lengths. - attn_metadata.max_seq_len += 1 - attn_metadata.seq_lens += 1 - # Consider max model length. - attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - self.max_model_len) - # For the requests that exceed the max model length, we set the - # sequence length to 1 to minimize their overheads in attention. - attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) - - # Compute the slot mapping. - block_numbers = clamped_positions // self.block_size - block_ids = block_table.gather(dim=1, - index=block_numbers.view(-1, 1)) - block_ids = block_ids.view(-1) - attn_metadata.slot_mapping = (block_ids * self.block_size + - clamped_positions % self.block_size) - # Mask out the slot mappings that exceed the max model length. - # Otherwise, the KV cache will be inadvertently updated with the - # padding tokens. - attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, - PADDING_SLOT_ID) - # copy inputs to buffer for cudagraph - self.input_ids[:batch_size] = input_ids - self.positions[:batch_size] = clamped_positions - self.hidden_states[:batch_size] = hidden_states + self.advance_speculative_state(draft_token_ids_list[-1], positions, + hidden_states, attn_metadata, + batch_size) + # copy inputs to buffer for cudagraph # Run the model. with set_forward_context(per_layer_attn_metadata, self.vllm_config, @@ -275,6 +221,58 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids + def advance_speculative_state(self, draft_token_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + batch_size: int): + """ + Advances the speculative decoding state and metadata by one step + + Parameters: + ---------- + draft_token_ids (torch.Tensor): Token IDs generated by the draft model + positions (torch.Tensor): Position indices for the draft tokens + hidden_states (torch.Tensor): Corresponding hidden states for the tokens + attn_metadata (FlashAttentionMetadata): FlashAttention metadata + batch_size (int): Number of sequences to update. + """ + + # Calculate number of thread blocks + grid = lambda meta: (triton.cdiv(batch_size, meta['BLOCK_SIZE']), ) + attn_metadata.slot_mapping = torch.empty_like(positions) + advance_state_kernel[grid]( + # === Input tensors === + draft_token_ids, + positions, + + # === Model input buffers to be updated === + self.input_ids[:batch_size], + self.positions[:batch_size], + + # === Metadata tensors === + attn_metadata.seq_lens, + attn_metadata.block_table, + attn_metadata.slot_mapping, + + # === Scalar configuration === + self.max_model_len, + self.block_size, + self.max_model_len // self.block_size, + + # === Execution control === + batch_size, + BLOCK_SIZE=1024, + PADDING_SLOT_ID=PADDING_SLOT_ID) + + self.hidden_states[:batch_size] = hidden_states + + # Increment the sequence lengths. + attn_metadata.max_seq_len += 1 + # Consider max model length. + attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, + self.max_model_len) + @staticmethod def prepare_inputs( # [batch_size + 1] @@ -301,7 +299,7 @@ def prepare_inputs( # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] cu_num_tokens = torch.zeros_like(cu_target_query_lens) torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) - token_indices = torch.empty( + token_indices = torch.zeros( num_tokens, dtype=torch.int32, device=cu_target_query_lens.device, @@ -316,6 +314,54 @@ def prepare_inputs( ) return cu_num_tokens, token_indices + def load_inputs(self, target_token_ids: torch.Tensor, + target_positions: torch.Tensor, + target_hidden_states: torch.Tensor, + next_token_ids_gpu: torch.Tensor, + cu_num_tokens: torch.Tensor, num_scheduled_tokens: int): + """ + Loads token ids, positions, etc. into the eagle model + + Logic moved from EagleProposer.propose() to here + + Parameters: + ---------- + target_token_ids (torch.Tensor): Draft-step token IDs + target_positions (torch.Tensor): Token Position indices + target_hidden_states (torch.Tensor): Token hidden states + next_token_ids_gpu (torch.Tensor): Sampled final token IDs + cu_num_tokens (torch.Tensor): Cumulative tokens from prepare_inputs() + num_scheduled_tokens (int): Total number of tokens scheduled + """ + + self.last_token_indices = cu_num_tokens[1:] - 1 + + # FA requires seq_len to have dtype int32. + self.seq_lens = (target_positions[self.last_token_indices] + 1).int() + + if self.method == "eagle3": + assert isinstance(self.model, Eagle3LlamaForCausalLM) + target_hidden_states = self.model.combine_hidden_states( + target_hidden_states) + assert target_hidden_states.shape[-1] == self.hidden_size + + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_scheduled_tokens - + 1] = target_token_ids[:num_scheduled_tokens][1:] + + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[self.last_token_indices] = next_token_ids_gpu + + # copy inputs to buffer for cudagraph + self.positions[: + num_scheduled_tokens] = target_positions[: + num_scheduled_tokens] + self.hidden_states[: + num_scheduled_tokens] = target_hidden_states[: + num_scheduled_tokens] + def load_model(self, target_model: nn.Module) -> None: draft_model_config = \ self.vllm_config.speculative_config.draft_model_config diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 5c37333cebc..98851698d2e 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -44,3 +44,82 @@ def prepare_eagle_input_kernel( index_start + offset, mask=offset < num_tokens, ) + + +@triton.jit +def advance_state_kernel( + draft_token_ids_ptr, + positions_ptr, + + # === Model input buffers to be updated === + model_input_ids_ptr, + model_positions_ptr, + + # === Metadata tensors === + seq_lens_ptr, + block_table_ptr, + slot_mapping_ptr, + + # === Scalar configuration === + model_max_len: int, + model_block_size: int, + model_block_stride: int, + + # === Execution control === + n_elements: int, + BLOCK_SIZE: tl.constexpr, + PADDING_SLOT_ID: tl.constexpr, +): + # Triton kernel to perform draft model state advancement. + + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + draft_token_list_last = tl.load(draft_token_ids_ptr + offsets, mask=mask) + position = tl.load(positions_ptr + offsets, mask=mask) + seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask) + + # Update the inputs. + # cast to int32 is crucial when eagle model is compiled. + # tensor.argmax() returns int64 by default. + input_id = draft_token_list_last.cast(tl.int32) + position = position + 1 + + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. + exceeds_max_model_len = position >= model_max_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_position = tl.where(exceeds_max_model_len, 0, position) + + # For the requests that exceed the max model length, we set the + # sequence length to 1 to minimize their overheads in attention. + seq_lens += 1 + seq_lens = tl.where(exceeds_max_model_len, 1, seq_lens) + + block_numbers = clamped_position // model_block_size + block_offsets = clamped_position % model_block_size + + block_ids = tl.load(block_table_ptr + model_block_stride * offsets + + block_numbers, + mask=mask) + + # Compute slot mapping + slot_mapping = block_ids * model_block_size + block_offsets + + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + slot_mapping = tl.where(exceeds_max_model_len, PADDING_SLOT_ID, + slot_mapping) + + tl.store(model_input_ids_ptr + offsets, input_id, mask=mask) + tl.store(positions_ptr + offsets, position, mask=mask) + tl.store(model_positions_ptr + offsets, clamped_position, mask=mask) + tl.store(seq_lens_ptr + offsets, seq_lens, mask=mask) + tl.store(slot_mapping_ptr + offsets, slot_mapping, mask=mask) \ No newline at end of file diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index ca2bfe83174..b96f6005f47 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -52,8 +52,11 @@ def num_tokens(self) -> int: def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: return self.prompt_token_ids[idx] - else: + elif idx - self.num_prompt_tokens < len(self.output_token_ids): return self.output_token_ids[idx - self.num_prompt_tokens] + # This is now precomputed, so we create a fallback if the idx is invalid + else: + return -1 # Invalid token id class InputBatch: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 40639fdf243..8f8670b6af4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -41,9 +41,8 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, - check_use_alibi, get_dtype_size, - is_pin_memory_available) + GiB_bytes, LazyLoader, cdiv, check_use_alibi, + get_dtype_size, is_pin_memory_available) from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) @@ -238,6 +237,12 @@ def __init__( self.slot_mapping = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) + self.remaining_req_indices = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device=self.device) + self.backup_next_token_ids = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device=self.device) # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -300,6 +305,20 @@ def __init__( device="cpu", pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() + self.remaining_req_indices_cpu = torch.zeros( + self.max_num_reqs, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.remaining_req_indices_np = self.remaining_req_indices_cpu.numpy() + self.remaining_req_count = 0 + self.discard_req_np = np.zeros(self.max_num_reqs) + self.backup_next_token_ids_cpu = torch.zeros( + self.max_num_reqs, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.backup_next_token_ids_np = self.backup_next_token_ids_cpu.numpy() # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it @@ -658,10 +677,40 @@ def _prepare_inputs( self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + # Prepare seq_len and num_token for eagle metadata self.seq_lens_np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) + num_tokens = [ + self.requests[r].num_tokens for r in self.input_batch.req_ids + ] + num_tokens_np = np.array(num_tokens, dtype=np.int32) + + # Record the index of requests that should not be sampled, + # so that we could clear the sampled tokens before returning + self.discard_req_np[:num_reqs] = \ + self.seq_lens_np[:num_reqs] < num_tokens_np + + # Also record indices of requests that should be sampled + self.remaining_req_count = np.count_nonzero( + self.discard_req_np[:num_reqs] == 0) + self.remaining_req_indices_np[:self.remaining_req_count] = np.nonzero( + self.discard_req_np == 0)[0][:self.remaining_req_count] + + self.remaining_req_indices[:self.remaining_req_count].copy_( + self.remaining_req_indices_cpu[:self.remaining_req_count], + non_blocking=True) + + # Precompute get_token_id for when there is no valid next token + self.backup_next_token_ids_np[:num_reqs] = np.array([ + self.requests[self.input_batch.req_ids[i]].get_token_id( + self.seq_lens_np[i]) for i in range(num_reqs) + ]) + + self.backup_next_token_ids[:num_reqs].copy_( + self.backup_next_token_ids_cpu[:num_reqs], non_blocking=True) + # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) @@ -1435,23 +1484,18 @@ def execute_model( if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) - # TODO(woosuk): The following loop can be slow since it iterates over - # the requests one by one. Optimize. - discard_sampled_tokens_req_indices = [] - for i, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - if seq_len < req_state.num_tokens: - # Ignore the sampled token for partial prefills. - # Rewind the generator state as if the token was not sampled. - # This relies on cuda-specific torch-internal impl details - generator = self.input_batch.generators.get(i) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - # Record the index of the request that should not be sampled, - # so that we could clear the sampled tokens before returning. - discard_sampled_tokens_req_indices.append(i) + num_reqs = self.input_batch.num_reqs + + discard_sampled_tokens_req_indices = np.nonzero( + self.discard_req_np[:num_reqs])[0] + + for i in discard_sampled_tokens_req_indices: + # Ignore the sampled token for partial prefills. + # Rewind the generator state as if the token was not sampled. + # This relies on cuda-specific torch-internal impl details + gen = self.input_batch.generators.get(int(i)) + if gen is not None: + gen.set_offset(gen.get_offset() - 4) # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. @@ -1468,18 +1512,12 @@ def execute_model( # Get the valid generated tokens. sampled_token_ids = sampler_output.sampled_token_ids max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. - valid_sampled_token_ids = sampled_token_ids.tolist() - else: - # Includes spec decode tokens. - valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() + + if not self.speculative_config or not self.speculative_config.use_eagle( + ): + valid_sampled_token_ids = self.get_valid_sampled_token_ids( + max_gen_len, sampled_token_ids, + discard_sampled_tokens_req_indices) if not self.speculative_config: # Speculative decoding is not enabled. @@ -1511,24 +1549,39 @@ def execute_model( ) elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) - # TODO(woosuk): Refactor the loop. - next_token_ids: list[int] = [] - for i, token_ids in enumerate(valid_sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = self.input_batch.req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) + + # Get all sampled tokens from valid requests + valid_sampled_token_ids_gpu = sampled_token_ids[ + self.remaining_req_indices[:self.remaining_req_count]] + + # Generate a mask for all valid tokens within those requests + if max_gen_len == 1: + valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, + dtype=torch.bool) + else: + valid_mask = ((valid_sampled_token_ids_gpu != -1) & + (valid_sampled_token_ids_gpu + < self.input_batch.vocab_size)) + + # Count valid tokens in each request + valid_sampled_count = valid_mask.sum(dim=1) + + batch = valid_sampled_token_ids_gpu.shape[0] + + # Get the rightmost valid index per row + last_valid_indices = valid_sampled_count - 1 + + # Get last valid token from each row + # (assume undefined state where there is no valid token) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, + last_valid_indices.unsqueeze(1)).squeeze(1) + + # Use last token if valid, pre-computed backup if not + next_token_ids_gpu = torch.where( + last_valid_indices != -1, selected_tokens, + self.backup_next_token_ids[:batch]) + # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. eagle_attn_metadata = attn_metadata[ @@ -1553,22 +1606,21 @@ def execute_model( target_slot_mapping = eagle_attn_metadata.slot_mapping cu_num_tokens = eagle_attn_metadata.query_start_loc else: - # TODO(woosuk): Refactor this. - num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens_tensor = async_tensor_h2d( - num_rejected_tokens, - dtype=torch.int32, - target_device=self.device, - pin_memory=True) - num_tokens = num_scheduled_tokens - sum(num_rejected_tokens) + # Recompute num_draft_tokens from cumsum + num_draft_tokens_gpu = torch.cat([ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - + spec_decode_metadata.cu_num_draft_tokens[:-1] + ]) + + num_rejected_tokens_gpu = torch.where( + num_draft_tokens_gpu > 0, + num_draft_tokens_gpu + 1 - valid_sampled_count, + torch.zeros_like(num_draft_tokens_gpu)) cu_num_tokens, token_indices = self.drafter.prepare_inputs( eagle_attn_metadata.query_start_loc, - num_rejected_tokens_tensor, - num_tokens, + num_rejected_tokens_gpu, + num_scheduled_tokens, ) target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] @@ -1579,16 +1631,50 @@ def execute_model( target_hidden_states = hidden_states[token_indices] target_slot_mapping = eagle_attn_metadata.slot_mapping[ token_indices] + + # load token ids, positions, etc. into the eagle model + self.drafter.load_inputs(target_token_ids, target_positions, + target_hidden_states, next_token_ids_gpu, + cu_num_tokens, num_scheduled_tokens) + + if self.speculative_config and self.speculative_config.use_eagle(): + valid_sampled_token_ids = self.get_valid_sampled_token_ids( + max_gen_len, sampled_token_ids, + discard_sampled_tokens_req_indices) + + if spec_decode_metadata is not None: + num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens_np = [ + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + else: + num_rejected_tokens_np = np.zeros(len( + self.input_batch.req_ids)) + + num_tokens = num_scheduled_tokens - int( + sum(num_rejected_tokens_np)) + + max_seq_len = int( + (self.seq_lens_np[:num_reqs] - num_rejected_tokens_np).max()) + max_num_tokens = int( + (self.seq_lens_np[:num_reqs] - + self.input_batch.num_computed_tokens_cpu[:num_reqs] - + num_rejected_tokens_np + ).max()) if spec_decode_metadata else max_seq_len + draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, target_slot_mapping=target_slot_mapping, - next_token_ids=next_token_ids, + next_token_ids=next_token_ids_gpu, cu_num_tokens=cu_num_tokens, block_table=block_table, sampling_metadata=sampling_metadata, - ) + num_tokens=num_tokens, + max_num_tokens=max_num_tokens, + max_seq_len=max_seq_len) spec_token_ids = draft_token_ids.tolist() # Clear KVConnector state after all KVs are generated. @@ -1608,6 +1694,34 @@ def execute_model( num_nans_in_logits=num_nans_in_logits, ) + def get_valid_sampled_token_ids( + self, max_gen_len: int, sampled_token_ids: torch.Tensor, + discard_sampled_tokens_req_indices: np.ndarray) -> list[list[int]]: + """ + Returns valid sampled tokens in a list of lists + + Parameters: + ---------- + - max_gen_len: Maximum length of the generated tokens + - sampled_token_ids: Tensor of sampled token IDs + - discard_sampled_tokens_req_indices: Indices that should not be sampled + """ + + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + + return valid_sampled_token_ids + def kv_connector_no_forward( self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: # KV send/recv even if no work to do.