diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py
index f4193fdb8bd..e25b27bf487 100644
--- a/examples/offline_inference/eagle.py
+++ b/examples/offline_inference/eagle.py
@@ -48,6 +48,7 @@ def parse_args():
     parser.add_argument("--enable_chunked_prefill", action="store_true")
     parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
     parser.add_argument("--temp", type=float, default=0)
+    parser.add_argument("--compilation_config", type=str, default="")
     return parser.parse_args()
 
 
@@ -94,6 +95,9 @@ def main():
             "max_model_len": max_model_len,
         },
         disable_log_stats=False,
+        compilation_config=(
+            json.loads(args.compilation_config) if args.compilation_config else None
+        ),
     )
 
     sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 153b67fe571..630f381fed4 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional
+
 import torch
 import torch.nn as nn
 
@@ -74,6 +76,7 @@ def __init__(
                                             1,
                                             device=device,
                                             dtype=torch.int32)
+        self.draft_attn_metadata = None
 
     def propose(
         self,
@@ -169,6 +172,13 @@ def propose(
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
 
+        # copy attention metadata for full cudagraph mode
+        if self.draft_attn_metadata is not None and num_tokens <= self.cudagraph_batch_sizes[-1]:
+            self.draft_attn_metadata.seq_lens[:attn_metadata.seq_lens.shape[0]].copy_(attn_metadata.seq_lens.clone())
+            self.draft_attn_metadata.slot_mapping[:attn_metadata.slot_mapping.shape[0]].copy_(attn_metadata.slot_mapping.clone())
+            self.draft_attn_metadata.query_start_loc[:attn_metadata.query_start_loc.shape[0]].copy_(attn_metadata.query_start_loc.clone())
+            self.draft_attn_metadata.block_table[:attn_metadata.block_table.shape[0]].copy_(attn_metadata.block_table.clone())
+
         with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
@@ -254,6 +264,13 @@ def propose(
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:batch_size] = hidden_states
 
+            # copy attention metadata for full cudagraph mode
+            if self.draft_attn_metadata is not None:
+                self.draft_attn_metadata.seq_lens[:attn_metadata.seq_lens.shape[0]].copy_(attn_metadata.seq_lens.clone())
+                self.draft_attn_metadata.slot_mapping[:attn_metadata.slot_mapping.shape[0]].copy_(attn_metadata.slot_mapping.clone())
+                self.draft_attn_metadata.query_start_loc[:attn_metadata.query_start_loc.shape[0]].copy_(attn_metadata.query_start_loc.clone())
+                self.draft_attn_metadata.block_table[:attn_metadata.block_table.shape[0]].copy_(attn_metadata.block_table.clone())
+
             # Run the model.
             with set_forward_context(per_layer_attn_metadata,
                                      self.vllm_config,
@@ -369,8 +386,13 @@ def load_model(self, target_model: nn.Module) -> None:
     def dummy_run(
         self,
         num_tokens: int,
+        attn_metadata: Optional[dict[str, Any]],
     ) -> None:
-        with set_forward_context(None, self.vllm_config,
+        if attn_metadata is not None and self.draft_attn_metadata is None:
+            attn_metadata[self.attn_layer_names[0]].scheduler_metadata = None
+            self.draft_attn_metadata = attn_metadata[self.attn_layer_names[0]]  # assume only one draft layer
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
                                  num_tokens=num_tokens):
             self.model(
                 self.input_ids[:num_tokens],
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 520d8fb186f..300919a5c22 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1860,7 +1860,7 @@ def maybe_randomize_inputs(self, input_ids: torch.Tensor):
         Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
         This is to help balance expert-selection
          - during profile_run
-         - during DP rank dummy run 
+         - during DP rank dummy run
         """
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
@@ -1982,7 +1982,7 @@ def _dummy_run(
 
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
-            self.drafter.dummy_run(num_tokens)
+            self.drafter.dummy_run(num_tokens, attn_metadata)
 
         logit_indices = np.cumsum(num_scheduled_tokens) - 1
         return hidden_states, hidden_states[logit_indices]
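
Usage sketch (outside the diff): the new --compilation_config flag passes a
JSON-encoded compilation config straight through to the LLM constructor, so
the full-cudagraph path can be exercised from the example script. Assuming
`full_cuda_graph` is an accepted CompilationConfig field in this vLLM build:

    python examples/offline_inference/eagle.py \
        --compilation_config '{"full_cuda_graph": true}'

An empty string (the default) maps to compilation_config=None, so existing
invocations of the example are unchanged.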
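Reviewer note (outside the diff): the in-place copy_() calls in propose()
exist because a captured CUDA graph replays kernels against the device
addresses recorded at capture time; the metadata tensors stashed in
self.draft_attn_metadata during dummy_run() must therefore be mutated in
place, never rebound to fresh tensors. A minimal, self-contained sketch of
that pattern, using hypothetical tensor names unrelated to vLLM:

    import torch

    # Persistent buffer whose address gets baked into the captured graph
    # (a stand-in for the draft model's attention metadata tensors).
    seq_lens = torch.zeros(64, dtype=torch.int32, device="cuda")

    # Warm up on a side stream before capture, as PyTorch recommends.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        _ = seq_lens * 2
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        doubled = seq_lens * 2  # reads seq_lens' fixed device address

    # Rebinding (seq_lens = new_tensor) would leave the graph reading the
    # old storage; mutating in place updates what replay actually sees.
    seq_lens[:8].copy_(torch.arange(8, dtype=torch.int32, device="cuda"))
    graph.replay()
    print(doubled[:8])  # now reflects the updated buffer

This is also why dummy_run() keeps a reference to the metadata object itself:
replay can only observe writes into those captured tensors, not new metadata
objects constructed for a later batch.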