[WIP] Run eagle with full cudagraph #20190

Draft · wants to merge 1 commit into base: main

4 changes: 4 additions & 0 deletions examples/offline_inference/eagle.py
@@ -48,6 +48,7 @@ def parse_args():
parser.add_argument("--enable_chunked_prefill", action="store_true")
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0)
parser.add_argument("--compilation_config", type=str, default="")
return parser.parse_args()


@@ -94,6 +95,9 @@ def main():
"max_model_len": max_model_len,
},
disable_log_stats=False,
compilation_config=(
json.loads(args.compilation_config) if args.compilation_config else None
),
Comment on lines +98 to +100
Contributor

medium

The direct call to json.loads can cause the script to crash with a json.JSONDecodeError if an invalid JSON string is passed to the --compilation_config argument. Consider adding a try-except block to handle potential parsing errors gracefully.

compilation_config = None
if args.compilation_config:
    try:
        compilation_config = json.loads(args.compilation_config)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON for --compilation_config: {e}") from e

)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
24 changes: 23 additions & 1 deletion vllm/v1/spec_decode/eagle.py
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional

import torch
import torch.nn as nn

@@ -74,6 +76,7 @@ def __init__(
1,
device=device,
dtype=torch.int32)
self.draft_attn_metadata = None

def propose(
self,
@@ -169,6 +172,13 @@ def propose(
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states

# copy attention metadata for full cudagraph mode
if self.draft_attn_metadata is not None and num_tokens <= self.cudagraph_batch_sizes[-1]:
self.draft_attn_metadata.seq_lens[:attn_metadata.seq_lens.shape[0]].copy_(attn_metadata.seq_lens.clone())
self.draft_attn_metadata.slot_mapping[:attn_metadata.slot_mapping.shape[0]].copy_(attn_metadata.slot_mapping.clone())
self.draft_attn_metadata.query_start_loc[:attn_metadata.query_start_loc.shape[0]].copy_(attn_metadata.query_start_loc.clone())
self.draft_attn_metadata.block_table[:attn_metadata.block_table.shape[0]].copy_(attn_metadata.block_table.clone())

with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
@@ -254,6 +264,13 @@ def propose(
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states

# copy attention metadata for full cudagraph mode
if self.draft_attn_metadata is not None:
self.draft_attn_metadata.seq_lens[:attn_metadata.seq_lens.shape[0]].copy_(attn_metadata.seq_lens.clone())
self.draft_attn_metadata.slot_mapping[:attn_metadata.slot_mapping.shape[0]].copy_(attn_metadata.slot_mapping.clone())
self.draft_attn_metadata.query_start_loc[:attn_metadata.query_start_loc.shape[0]].copy_(attn_metadata.query_start_loc.clone())
self.draft_attn_metadata.block_table[:attn_metadata.block_table.shape[0]].copy_(attn_metadata.block_table.clone())

# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
@@ -369,8 +386,13 @@ def load_model(self, target_model: nn.Module) -> None:
def dummy_run(
self,
num_tokens: int,
attn_metadata: Optional[dict[str, Any]],
) -> None:
with set_forward_context(None, self.vllm_config,
if attn_metadata is not None and self.draft_attn_metadata is None:
attn_metadata[self.attn_layer_names[0]].scheduler_metadata = None
self.draft_attn_metadata = attn_metadata[self.attn_layer_names[0]] # assume only one draft layer
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens):
self.model(
self.input_ids[:num_tokens],
4 changes: 2 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1860,7 +1860,7 @@ def maybe_randomize_inputs(self, input_ids: torch.Tensor):
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
This is to help balance expert-selection
- during profile_run
- during DP rank dummy run
- during DP rank dummy run
"""
dp_size = self.vllm_config.parallel_config.data_parallel_size
randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
@@ -1982,7 +1982,7 @@ def _dummy_run(

if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens)
self.drafter.dummy_run(num_tokens, attn_metadata)
Collaborator

Here's my hypothesis:

  • the attn_metadata contains tensors
  • cudagraph capture is baking in the addresses of those tensors
  • at runtime, the captured cudagraphs still read from those tensors.

Does the eagle forward pass use the tensors in the attn_metadata? If so, every time we invoke the eagle head, we may need to copy data into the tensors in the attn_metadata.
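
As a minimal standalone sketch of the pattern described above (placeholder tensors, not vLLM code): a captured CUDA graph replays against the exact device buffers it saw at capture time, so fresh data has to be written into those buffers with copy_() before every replay.

import torch

# Buffers whose device addresses get baked into the captured graph.
static_input = torch.zeros(8, device="cuda")
static_output = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, as recommended by PyTorch.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output.copy_(static_input * 2)
torch.cuda.current_stream().wait_stream(s)

# Capture the computation once.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output.copy_(static_input * 2)

# Replay with fresh data: copy into the captured buffer, then replay.
static_input.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
g.replay()
# static_output now holds [0, 2, 4, ...]; rebinding static_input to a new
# tensor instead of using copy_() would leave the graph reading the stale buffer.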

Collaborator Author

You are right, this is partially the reason for the numerical gap. As an experiment, I copied the attn_metadata constructed for eager mode into the captured attn_metadata in the latest commit:

# copy attention metadata for full cudagraph mode
if self.draft_attn_metadata is not None:
    self.draft_attn_metadata.seq_lens[:attn_metadata.seq_lens.shape[0]].copy_(attn_metadata.seq_lens.clone())
    self.draft_attn_metadata.slot_mapping[:attn_metadata.slot_mapping.shape[0]].copy_(attn_metadata.slot_mapping.clone())
    self.draft_attn_metadata.query_start_loc[:attn_metadata.query_start_loc.shape[0]].copy_(attn_metadata.query_start_loc.clone())
    self.draft_attn_metadata.block_table[:attn_metadata.block_table.shape[0]].copy_(attn_metadata.block_table.clone())

As a result, I got better numerics, but there is still a gap compared with piecewise mode:

  • VLLM_USE_V1=1 python examples/offline_inference/eagle.py --num_spec_tokens 7 --num_prompts 1 --compilation_config '{"full_cuda_graph": true, "cudagraph_capture_sizes": [1]}'
--------------------------------------------------
mean acceptance length: 2.46
--------------------------------------------------
acceptance at token 0:0.69
acceptance at token 1:0.38
acceptance at token 2:0.20
acceptance at token 3:0.12
acceptance at token 4:0.06
acceptance at token 5:0.00
acceptance at token 6:0.00
  • VLLM_USE_V1=1 python examples/offline_inference/eagle.py --num_spec_tokens 7 --num_prompts 1 --compilation_config '{"full_cuda_graph": false, "cudagraph_capture_sizes": [1]}'
--------------------------------------------------
mean acceptance length: 2.82
--------------------------------------------------
acceptance at token 0:0.77
acceptance at token 1:0.51
acceptance at token 2:0.28
acceptance at token 3:0.13
acceptance at token 4:0.05
acceptance at token 5:0.03
acceptance at token 6:0.03

Collaborator Author

So it seems there might still be some discrepancy in attention computation between eager mode and cudagraph mode. I will try to investigate more, and would also appreciate any suggestions on what to check from the torch.compile perspective.
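
One hypothetical way to narrow this down (a sketch, not part of the PR; only the four field names come from the copies above, everything else is a placeholder) would be to assert right before each draft forward that the captured buffers are identical to the freshly built eager metadata for the rows actually used, so stale inputs can be ruled out and any remaining gap attributed to the captured kernel path:

import torch

def assert_metadata_in_sync(captured, eager) -> None:
    # Compare the slices of the captured (cudagraph) buffers that the current
    # batch actually uses against the eager-mode metadata; a mismatch here
    # would point at stale inputs rather than the kernels themselves.
    for name in ("seq_lens", "slot_mapping", "query_start_loc", "block_table"):
        src = getattr(eager, name)
        dst = getattr(captured, name)[: src.shape[0]]
        torch.testing.assert_close(dst, src, rtol=0, atol=0)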


logit_indices = np.cumsum(num_scheduled_tokens) - 1
return hidden_states, hidden_states[logit_indices]