Skip to content

Commit 7ca19b0

Browse files
committed
run eagle with full cudagraph
Signed-off-by: qizixi <qizixi@meta.com>
1 parent 61f4fc5 commit 7ca19b0

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

examples/offline_inference/eagle.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def parse_args():
4848
parser.add_argument("--enable_chunked_prefill", action="store_true")
4949
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
5050
parser.add_argument("--temp", type=float, default=0)
51+
parser.add_argument("--compilation_config", type=str, default="")
5152
return parser.parse_args()
5253

5354

@@ -94,6 +95,9 @@ def main():
9495
"max_model_len": max_model_len,
9596
},
9697
disable_log_stats=False,
98+
compilation_config=(
99+
json.loads(args.compilation_config) if args.compilation_config else None
100+
),
97101
)
98102

99103
sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)

vllm/v1/spec_decode/eagle.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from typing import Any, Optional
4+
35
import torch
46
import torch.nn as nn
57

@@ -369,8 +371,10 @@ def load_model(self, target_model: nn.Module) -> None:
369371
def dummy_run(
370372
self,
371373
num_tokens: int,
374+
attn_metadata: Optional[dict[str, Any]],
372375
) -> None:
373-
with set_forward_context(None, self.vllm_config,
376+
with set_forward_context(attn_metadata,
377+
self.vllm_config,
374378
num_tokens=num_tokens):
375379
self.model(
376380
self.input_ids[:num_tokens],

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,7 +1860,7 @@ def maybe_randomize_inputs(self, input_ids: torch.Tensor):
18601860
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
18611861
This is to help balance expert-selection
18621862
- during profile_run
1863-
- during DP rank dummy run
1863+
- during DP rank dummy run
18641864
"""
18651865
dp_size = self.vllm_config.parallel_config.data_parallel_size
18661866
randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
@@ -1982,7 +1982,7 @@ def _dummy_run(
19821982

19831983
if self.speculative_config and self.speculative_config.use_eagle():
19841984
assert isinstance(self.drafter, EagleProposer)
1985-
self.drafter.dummy_run(num_tokens)
1985+
self.drafter.dummy_run(num_tokens, attn_metadata)
19861986

19871987
logit_indices = np.cumsum(num_scheduled_tokens) - 1
19881988
return hidden_states, hidden_states[logit_indices]

0 commit comments

Comments
 (0)