From c8e31e6492983070d9a3937bee67df9d6ab279f6 Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Fri, 4 Jul 2025 00:14:44 +0000 Subject: [PATCH 1/4] Spec decode with probs Signed-off-by: Andy Lo --- tests/v1/spec_decode/test_eagle.py | 45 ++-- tests/v1/spec_decode/test_scheduling.py | 253 +++++++++++++++++++++++ tests/v1/worker/test_gpu_model_runner.py | 2 +- vllm/v1/spec_decode/eagle.py | 31 +-- vllm/v1/worker/gpu_input_batch.py | 1 + vllm/v1/worker/gpu_model_runner.py | 52 ++++- 6 files changed, 348 insertions(+), 36 deletions(-) create mode 100644 tests/v1/spec_decode/test_scheduling.py diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 5efab2c1440..4c4e37329ad 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -294,34 +294,49 @@ def create_deterministic_logits(token_ids): block_table = torch.randint(0, 10, (batch_size, 10), device=device) sampling_metadata = mock.MagicMock() + # Simulate mixed greedy and non-greedy requests + sampling_metadata.all_greedy = False + sampling_metadata.temperature = torch.tensor([-1, 0.7], device=device) # Call the method under test - result = proposer.propose(target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - target_slot_mapping=target_slot_mapping, - next_token_ids=next_token_ids, - cu_num_tokens=cu_num_tokens, - block_table=block_table, - sampling_metadata=sampling_metadata) - - assert result.shape == (batch_size, num_speculative_tokens) + result, result_probs = proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + target_slot_mapping=target_slot_mapping, + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=block_table, + sampling_metadata=sampling_metadata) + + assert len(result) == batch_size + assert len(result_probs) == batch_size + assert all(len(tokens) == num_speculative_tokens for tokens in result) + assert all(r.shape == (num_speculative_tokens, vocab_size) + for r in result_probs) # Create expected tokens based on our token pattern if num_speculative_tokens == 1: # Example for num_speculative_tokens=1: # [[42], [60]] - expected_tokens = torch.tensor( - [[base_token_ids[0]], [base_token_ids[1]]], device=device) + expected_tokens = torch.tensor([[base_token_ids[0]], + [base_token_ids[1]]]) + expected_probs = torch.zeros((batch_size, 1, vocab_size), + device=device) + for i, token_id in enumerate(base_token_ids): + expected_probs[i, 0, token_id] = 1.0 else: # Example for num_speculative_tokens=3: # [[42, 43, 44], [60, 61, 62]] expected_tokens = torch.zeros((batch_size, num_speculative_tokens), - dtype=torch.int64, - device=device) + dtype=torch.int64) + expected_probs = torch.zeros( + (batch_size, num_speculative_tokens, vocab_size), device=device) for i in range(batch_size): for j in range(num_speculative_tokens): expected_tokens[i, j] = base_token_ids[i] + j + expected_probs[i, j, base_token_ids[i] + j] = 1.0 # Verify all tokens match our expectations - assert torch.equal(result, expected_tokens) + assert torch.equal(torch.tensor(result), expected_tokens) + torch.testing.assert_close(torch.stack(result_probs), expected_probs) diff --git a/tests/v1/spec_decode/test_scheduling.py b/tests/v1/spec_decode/test_scheduling.py new file mode 100644 index 00000000000..ecde5b082c1 --- /dev/null +++ b/tests/v1/spec_decode/test_scheduling.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import tempfile + +import pytest +import torch + +from tests.v1.worker.test_gpu_model_runner import _schedule_new_request +from vllm.config import VllmConfig +from vllm.distributed import (cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel) +from vllm.engine.arg_utils import EngineArgs +from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput +from vllm.v1.engine.core import get_kv_cache_config +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +model_dir = "meta-llama/Llama-3.1-8B-Instruct" +eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + + +@pytest.fixture() +def should_do_global_cleanup_after_test(request) -> bool: + # So we can share the DraftModelProposer between tests + return False + + +@pytest.fixture(scope="class") +def monkeyclass(): + with pytest.MonkeyPatch.context() as mp: + yield mp + + +@pytest.fixture(scope="class") +def spec_decode_vllm_config_and_env_setup(monkeyclass: pytest.MonkeyPatch): + with monkeyclass.context() as m: + m.setenv("VLLM_USE_V1", "1") + vllm_config = EngineArgs(model=model_dir, + max_model_len=256, + cuda_graph_sizes=[1, 2, 4], + gpu_memory_utilization=0.8, + speculative_config={ + "model": eagle_dir, + "method": "eagle", + "num_speculative_tokens": 2, + }).create_engine_config() + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend="nccl", + ) + initialize_model_parallel(1, 1) + yield vllm_config + cleanup_dist_env_and_memory() + + +@pytest.fixture(scope="class") +def mock_spec_decode_model_runner( + spec_decode_vllm_config_and_env_setup: VllmConfig): + model_runner = GPUModelRunner(spec_decode_vllm_config_and_env_setup, + torch.device("cuda")) + model_runner.load_model() + kv_cache_spec = model_runner.get_kv_cache_spec() + + kv_cache_config = get_kv_cache_config( + spec_decode_vllm_config_and_env_setup, kv_cache_spec, 1024**3) # 1GB + model_runner.initialize_kv_cache(kv_cache_config) + yield model_runner + + +class TestSpecDecodeScheduling: + + def test_spec_decode_partial_scheduling( + self, mock_spec_decode_model_runner: GPUModelRunner): + """Make sure we don't crash when the scheduler schedules only a subset + of the requests. + + Four iterations: + 1. Schedule both req1 (w/ 0 draft) and req2 (w/ 0 draft) + 2. Schedule only req1 (w/ 1 draft) + 3. Schedule both req1 (w/ 1 draft) and req2 (w/ 2 draft) + 4. 
Terminate req1 and req2 + """ + # Schedule both req1 and req2 on the first iteration + scheduler_output = _schedule_new_request("req1", "req2") + mock_spec_decode_model_runner.execute_model(scheduler_output) + + # Only schedule req1 on the second iteration + cached_req_data = CachedRequestData( + req_ids=["req1"], + resumed_from_preemption=[False], + new_token_ids=[[3]], + new_block_ids=[([], )], + num_computed_tokens=[3], + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={"req1": 2}, + total_num_scheduled_tokens=2, + scheduled_spec_decode_tokens={"req1": [1001]}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0], + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + mock_spec_decode_model_runner.execute_model(scheduler_output) + + # Schedule both req1 and req2 on the third iteration + cached_req_data = CachedRequestData( + req_ids=["req1", "req2"], + resumed_from_preemption=[False, False], + new_token_ids=[[10], [11]], + new_block_ids=[([], ), ([], )], + num_computed_tokens=[4, 3], + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={ + "req1": 2, + "req2": 3 + }, + total_num_scheduled_tokens=5, + scheduled_spec_decode_tokens={ + "req1": [1001], + "req2": [2001, 2002] + }, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0], + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + mock_spec_decode_model_runner.execute_model(scheduler_output) + + # Terminate both req1 and req2 + cached_req_data = CachedRequestData( + req_ids=[], + resumed_from_preemption=[], + new_token_ids=[], + new_block_ids=[], + num_computed_tokens=[], + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={}, + total_num_scheduled_tokens=0, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0], + finished_req_ids={"req1", "req2"}, + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + mock_spec_decode_model_runner.execute_model(scheduler_output) + + def test_spec_decode_preemption_scheduling( + self, mock_spec_decode_model_runner: GPUModelRunner): + """Make sure we don't crash when the scheduler preempts a request. + + Four iterations: + 1. Schedule req1 (w/ 0 draft) and req2 (w/ 0 draft) + 2. Schedule req1 (w/ 1 draft) and preempt req2 + 3. Schedule req1 (w/ 1 draft) and resume req2 (w/ 2 draft) + 4. 
Terminate req1 and req2 + """ + # Schedule both req1 and req2 on the first iteration + scheduler_output = _schedule_new_request("req1", "req2") + mock_spec_decode_model_runner.execute_model(scheduler_output) + + # Only schedule req1 on the second iteration + cached_req_data = CachedRequestData( + req_ids=["req1"], + resumed_from_preemption=[False], + new_token_ids=[[3]], + new_block_ids=[([], )], + num_computed_tokens=[3], + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={"req1": 2}, + total_num_scheduled_tokens=2, + scheduled_spec_decode_tokens={"req1": [1001]}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0], + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + mock_spec_decode_model_runner.execute_model(scheduler_output) + + # Schedule both req1 and req2 on the third iteration + cached_req_data = CachedRequestData( + req_ids=["req1", "req2"], + resumed_from_preemption=[False, True], + new_token_ids=[[10], [11]], + new_block_ids=[([], ), ([0], )], + num_computed_tokens=[4, 0], + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={ + "req1": 2, + "req2": 6 + }, + total_num_scheduled_tokens=8, + scheduled_spec_decode_tokens={ + "req1": [1001], + "req2": [2001, 2002] + }, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0], + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + mock_spec_decode_model_runner.execute_model(scheduler_output) + + # Terminate both req1 and req2 + cached_req_data = CachedRequestData( + req_ids=[], + resumed_from_preemption=[], + new_token_ids=[], + new_block_ids=[], + num_computed_tokens=[], + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={}, + total_num_scheduled_tokens=0, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0], + finished_req_ids={"req1", "req2"}, + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + mock_spec_decode_model_runner.execute_model(scheduler_output) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index d13df553db6..cc037802256 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -139,7 +139,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[0], finished_req_ids=set(), free_encoder_input_ids=[], structured_output_request_ids={}, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 6661d984a77..06602a2565b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -92,7 +92,7 @@ def propose( # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: + ) -> tuple[list[list[int]], list[torch.Tensor]]: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] last_token_indices = cu_num_tokens[1:] - 1 @@ -183,19 +183,26 @@ def propose( last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] 
logits = self.model.compute_logits(sample_hidden_states, None) - draft_token_ids = logits.argmax(dim=-1) + draft_token_ids, draft_probs = compute_probs_and_sample_next_token( + logits, sampling_metadata) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: - # [batch_size, 1] - return draft_token_ids.view(-1, 1) + # [batch_size, 1] and [batch_size, 1, vocab_size] + return ( + draft_token_ids.view(-1, 1).tolist(), + draft_probs.unsqueeze(1).unbind(0), + ) # TODO: Currently, MTP module released by deepseek only has # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. # Generate the remaining draft tokens. - draft_token_ids_list = [draft_token_ids] + # [num_speculative_tokens, batch_size] + draft_token_ids_list: list[torch.Tensor] = [draft_token_ids] + # [num_speculative_tokens, batch_size, vocab_size] + draft_probs_list: list[torch.Tensor] = [draft_probs] positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] @@ -268,12 +275,16 @@ def propose( None) # TODO(wenlong): get more than one token for tree attention - draft_token_ids = logits.argmax(dim=-1) + draft_token_ids, draft_probs = compute_probs_and_sample_next_token( + logits, sampling_metadata) draft_token_ids_list.append(draft_token_ids) + draft_probs_list.append(draft_probs) # [batch_size, num_speculative_tokens] - draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - return draft_token_ids + draft_token_ids = torch.stack(draft_token_ids_list, dim=1).tolist() + # [batch_size, num_speculative_tokens, vocab_size] + draft_probs_list = torch.stack(draft_probs_list, dim=1).unbind(0) + return draft_token_ids, draft_probs_list @staticmethod def prepare_inputs( @@ -398,10 +409,6 @@ def validate_same_kv_cache_group(self, ) == 1, "All eagle layers should belong to the same kv cache group" -# NOTE(woosuk): Currently, the below code is not used and we always use argmax -# to sample the draft tokens. We will use this after we find a way to manage -# the draft prob tensor. -# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details. # FIXME(woosuk): The logic here is duplicated with the main sampling code. # We should refactor this to reuse the same sampling implementation. def compute_probs_and_sample_next_token( diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1a79d72be0a..b45abb58ce5 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -43,6 +43,7 @@ class CachedRequestState: mrope_position_delta: Optional[int] = None lora_request: Optional[LoRARequest] = None + draft_probs: Optional[torch.Tensor] = None def __post_init__(self): self.num_prompt_tokens = len(self.prompt_token_ids) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5a26e88db1f..4a7207f6e05 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1448,9 +1448,30 @@ def execute_model( # separate storage from the original `logits` tensor. Therefore, # it is safe to update `target_logits` in place. 
target_logits = logits[spec_decode_metadata.target_logits_indices] + + draft_probs_list: list[torch.Tensor] = [] + has_draft_probs: list[bool] = [] + for i, req_id in enumerate(self.input_batch.req_ids): + draft_length = spec_decode_metadata.num_draft_tokens[i] + if draft_length > 0: + draft_probs = self.requests[req_id].draft_probs + if draft_probs is not None: + # <= since not every draft token is necessarily + # scheduled + assert draft_length <= draft_probs.shape[0] + has_draft_probs.append(True) + # Not every draft token is necessarily scheduled + draft_probs_list.append(draft_probs[:draft_length]) + else: + has_draft_probs.append(False) + assert all(has_draft_probs) or not any(has_draft_probs), ( + "Some requests have draft logits while others do not.") + + draft_probs = (torch.cat(draft_probs_list, dim=0) + if len(draft_probs_list) > 0 else None) output_token_ids = self.rejection_sampler( spec_decode_metadata, - None, # draft_probs + draft_probs, target_logits, bonus_token_ids, sampling_metadata, @@ -1533,9 +1554,9 @@ def execute_model( if not self.speculative_config: # Speculative decoding is not enabled. - spec_token_ids = None + spec_token_ids = spec_probs = None else: - spec_token_ids = self.propose_draft_token_ids( + spec_token_ids, spec_probs = self.propose_draft( scheduler_output, valid_sampled_token_ids, sampling_metadata, @@ -1545,6 +1566,11 @@ def execute_model( spec_decode_metadata, attn_metadata, ) + # Save the draft probs for future use, usually the next step. + if spec_probs is not None: + for i, spec_prob in enumerate(spec_probs): + req_id = self.input_batch.req_ids[i] + self.requests[req_id].draft_probs = spec_prob # Clear KVConnector state after all KVs are generated. if has_kv_transfer_group(): @@ -1565,7 +1591,7 @@ def execute_model( num_nans_in_logits=num_nans_in_logits, ) - def propose_draft_token_ids( + def propose_draft( self, scheduler_output: "SchedulerOutput", sampled_token_ids: list[list[int]], @@ -1575,12 +1601,19 @@ def propose_draft_token_ids( aux_hidden_states: Optional[torch.Tensor], spec_decode_metadata: Optional[SpecDecodeMetadata], attn_metadata: dict[str, Any], - ) -> list[list[int]]: + ) -> tuple[list[list[int]], Optional[list[torch.Tensor]]]: + """Generate the draft for the next step. + + Returns: + - The draft token ids. + - The draft probs (optional). + """ num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(self.drafter, NgramProposer) spec_token_ids = self.propose_ngram_draft_token_ids( sampled_token_ids) + spec_probs = None elif self.speculative_config.method == "medusa": assert isinstance(self.drafter, MedusaProposer) if sample_hidden_states.shape[0] == len(sampled_token_ids): @@ -1601,6 +1634,7 @@ def propose_draft_token_ids( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) + spec_probs = None elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. 
@@ -1673,7 +1707,7 @@ def propose_draft_token_ids( target_hidden_states = hidden_states[token_indices] target_slot_mapping = eagle_attn_metadata.slot_mapping[ token_indices] - draft_token_ids = self.drafter.propose( + spec_token_ids, spec_probs = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, @@ -1683,8 +1717,10 @@ def propose_draft_token_ids( block_table=block_table, sampling_metadata=sampling_metadata, ) - spec_token_ids = draft_token_ids.tolist() - return spec_token_ids + else: + raise ValueError(f"Unsupported speculative decoding method: " + f"{self.speculative_config.method}") + return spec_token_ids, spec_probs def kv_connector_no_forward( self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: From 5937e7b4be7d9eb5821bf83e78b144afb6f54f7a Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Fri, 4 Jul 2025 01:23:38 +0100 Subject: [PATCH 2/4] Update vllm/v1/spec_decode/eagle.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Andy Lo --- vllm/v1/spec_decode/eagle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 06602a2565b..7e9fefc42b5 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -199,9 +199,9 @@ def propose( # there's a multi-layer MTP module. # Generate the remaining draft tokens. - # [num_speculative_tokens, batch_size] + # Each tensor in the list has shape [batch_size]. draft_token_ids_list: list[torch.Tensor] = [draft_token_ids] - # [num_speculative_tokens, batch_size, vocab_size] + # Each tensor in the list has shape [batch_size, vocab_size]. draft_probs_list: list[torch.Tensor] = [draft_probs] positions = target_positions[last_token_indices] From c2dd56604ce34e37672a8e8b73b6064493e8f53a Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Tue, 8 Jul 2025 12:32:54 +0000 Subject: [PATCH 3/4] Warmup Signed-off-by: Andy Lo --- vllm/v1/worker/gpu_model_runner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4a7207f6e05..554c05c80bc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2170,10 +2170,10 @@ def _dummy_sampler_run( draft_token_ids, self.device) num_tokens = sum(len(ids) for ids in draft_token_ids) - # draft_probs = torch.randn( - # num_tokens, logits.shape[-1], device=self.device, - # dtype=logits.dtype) - draft_probs = None + draft_probs = torch.randn(num_tokens, + logits.shape[-1], + device=self.device, + dtype=logits.dtype) target_logits = torch.randn(num_tokens, logits.shape[-1], device=self.device, From 20e43fd26658b1f99305862d8741077abe45acd5 Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Wed, 9 Jul 2025 03:36:11 +0000 Subject: [PATCH 4/4] torch compile sampling kernel Signed-off-by: Andy Lo --- vllm/v1/spec_decode/eagle.py | 60 +++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 7e9fefc42b5..780ce0ae5e0 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -12,6 +12,7 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.platforms import current_platform from 
vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig @@ -329,7 +330,7 @@ def prepare_inputs( def load_model(self, target_model: nn.Module) -> None: draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config + self.speculative_config.draft_model_config target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, Attention).keys()) @@ -371,7 +372,7 @@ def load_model(self, target_model: nn.Module) -> None: # share lm_head with the target model if needed # some model definition do not define lm_head explicitly # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM - if self.vllm_config.speculative_config.method != "eagle3" and \ + if self.speculative_config.method != "eagle3" and \ hasattr(target_language_model, "lm_head"): logger.info("Loading EAGLE LM head weights from the target model.") self.model.lm_head = target_language_model.lm_head @@ -383,11 +384,18 @@ def dummy_run( ) -> None: with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): - self.model( + ret_hidden_states = self.model( self.input_ids[:num_tokens], self.positions[:num_tokens], self.hidden_states[:num_tokens], ) + if self.method == "deepseek_mtp": + last_hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states + logits = self.model.compute_logits(last_hidden_states, None) + temperature = torch.ones(num_tokens, device=logits.device) + _mixed_sample(logits, temperature) def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: @@ -409,21 +417,14 @@ def validate_same_kv_cache_group(self, ) == 1, "All eagle layers should belong to the same kv cache group" -# FIXME(woosuk): The logic here is duplicated with the main sampling code. -# We should refactor this to reuse the same sampling implementation. -def compute_probs_and_sample_next_token( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> tuple[torch.Tensor, torch.Tensor]: - if sampling_metadata.all_greedy: - # For greedy requests, draft_probs is not used in rejection sampling. - # Therefore, we can just return the logits. - probs = logits - next_token_ids = logits.argmax(dim=-1) - return next_token_ids, probs - - is_greedy = sampling_metadata.temperature == -1 - temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) +@torch.compile(dynamic=True, + backend=current_platform.simple_compile_backend, + mode="max-autotune-no-cudagraphs") +def _mixed_sample( + logits: torch.Tensor, + temperature: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + is_greedy = temperature == -1 + temperature = torch.where(is_greedy, 1.0, temperature) logits.div_(temperature.view(-1, 1)) probs = logits.softmax(dim=-1, dtype=torch.float32) @@ -435,14 +436,23 @@ def compute_probs_and_sample_next_token( # TODO(woosuk): Consider seeds. q = torch.empty_like(probs) q.exponential_() + q[is_greedy, :] = 1.0 # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs # will be used later for rejection sampling. next_token_ids = probs.div(q).argmax(dim=-1).view(-1) - if not sampling_metadata.all_random: - greedy_token_ids = probs.argmax(dim=-1) - next_token_ids = torch.where( - is_greedy, - greedy_token_ids, - next_token_ids, - ) return next_token_ids, probs + + +# FIXME(woosuk): The logic here is duplicated with the main sampling code. +# We should refactor this to reuse the same sampling implementation. 
+def compute_probs_and_sample_next_token( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> tuple[torch.Tensor, torch.Tensor]: + if sampling_metadata.all_greedy: + # For greedy requests, draft_probs is not used in rejection sampling. + # Therefore, we can just return the logits. + probs = logits + next_token_ids = logits.argmax(dim=-1) + return next_token_ids, probs + return _mixed_sample(logits, sampling_metadata.temperature)