@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional
+
 import torch
 import torch.nn as nn
 
@@ -74,6 +76,7 @@ def __init__(
                                    1,
                                    device=device,
                                    dtype=torch.int32)
+        self.draft_attn_metadata = None
 
     def propose(
         self,
@@ -169,6 +172,13 @@ def propose(
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
 
+        # Copy attention metadata into the buffers captured for full cudagraph mode.
+        if self.draft_attn_metadata is not None and num_tokens <= self.cudagraph_batch_sizes[-1]:
+            self.draft_attn_metadata.seq_lens[:attn_metadata.seq_lens.shape[0]].copy_(attn_metadata.seq_lens)
+            self.draft_attn_metadata.slot_mapping[:attn_metadata.slot_mapping.shape[0]].copy_(attn_metadata.slot_mapping)
+            self.draft_attn_metadata.query_start_loc[:attn_metadata.query_start_loc.shape[0]].copy_(attn_metadata.query_start_loc)
+            self.draft_attn_metadata.block_table[:attn_metadata.block_table.shape[0]].copy_(attn_metadata.block_table)
+
         with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
@@ -254,6 +264,13 @@ def propose(
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:batch_size] = hidden_states
 
+            # Copy attention metadata into the buffers captured for full cudagraph mode.
+            if self.draft_attn_metadata is not None:
+                self.draft_attn_metadata.seq_lens[:attn_metadata.seq_lens.shape[0]].copy_(attn_metadata.seq_lens)
+                self.draft_attn_metadata.slot_mapping[:attn_metadata.slot_mapping.shape[0]].copy_(attn_metadata.slot_mapping)
+                self.draft_attn_metadata.query_start_loc[:attn_metadata.query_start_loc.shape[0]].copy_(attn_metadata.query_start_loc)
+                self.draft_attn_metadata.block_table[:attn_metadata.block_table.shape[0]].copy_(attn_metadata.block_table)
+
             # Run the model.
             with set_forward_context(per_layer_attn_metadata,
                                      self.vllm_config,
@@ -369,8 +386,13 @@ def load_model(self, target_model: nn.Module) -> None:
     def dummy_run(
         self,
         num_tokens: int,
+        attn_metadata: Optional[dict[str, Any]],
     ) -> None:
-        with set_forward_context(None, self.vllm_config,
+        if attn_metadata is not None and self.draft_attn_metadata is None:
+            attn_metadata[self.attn_layer_names[0]].scheduler_metadata = None
+            self.draft_attn_metadata = attn_metadata[self.attn_layer_names[0]]  # assume only one draft layer
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
                                  num_tokens=num_tokens):
             self.model(
                 self.input_ids[:num_tokens],
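
Note on the pattern used above: under full CUDA graph capture, the attention-metadata tensors that are live during the dummy run are recorded into the graph, so later steps must write fresh values into those same tensors in place rather than rebind them to new ones. A minimal sketch of that in-place update, with hypothetical buffer names and a plain tensor standing in for the real attention metadata:

import torch

# Hypothetical persistent buffer sized for the largest batch that will be
# captured in a CUDA graph; the graph keeps a reference to this exact tensor.
MAX_NUM_SEQS = 8
captured_seq_lens = torch.zeros(MAX_NUM_SEQS, dtype=torch.int32)


def sync_seq_lens(live_seq_lens: torch.Tensor) -> None:
    # Write the current values into a prefix of the captured buffer in place.
    # Rebinding the name to a new tensor would leave the replayed graph
    # reading stale data, which is why the diff uses copy_() on slices.
    captured_seq_lens[:live_seq_lens.shape[0]].copy_(live_seq_lens)


sync_seq_lens(torch.tensor([3, 5, 7], dtype=torch.int32))
print(captured_seq_lens)  # tensor([3, 5, 7, 0, 0, 0, 0, 0], dtype=torch.int32)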