3 files changed (+12, -4 lines)

examples/offline_inference

@@ -48,6 +48,7 @@ def parse_args():
     parser.add_argument("--enable_chunked_prefill", action="store_true")
     parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
     parser.add_argument("--temp", type=float, default=0)
+    parser.add_argument("--compilation_config", type=str, default="")
     return parser.parse_args()

@@ -94,6 +95,9 @@ def main():
             "max_model_len": max_model_len,
         },
         disable_log_stats=False,
+        compilation_config=(
+            json.loads(args.compilation_config) if args.compilation_config else None
+        ),
     )

     sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
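The new flag takes the compilation config as a JSON string and decodes it only when non-empty, so the default empty string passes None and leaves the engine's compilation behavior untouched. A minimal standalone sketch of that parsing pattern (only the argparse/json lines mirror the diff; the print is a stand-in for the LLM call site):

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--compilation_config", type=str, default="")
args = parser.parse_args()

# Empty string -> None, so the engine keeps its default compilation config;
# otherwise the JSON string becomes a plain dict of overrides.
compilation_config = (
    json.loads(args.compilation_config) if args.compilation_config else None
)
print(compilation_config)

Invoked, for example, as --compilation_config '{"level": 3}' (the key here is illustrative, not taken from this diff).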
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional
+
 import torch
 import torch.nn as nn

@@ -169,7 +171,7 @@ def propose(
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states

-        with set_forward_context(per_layer_attn_metadata,
+        with set_forward_context(None,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
             ret_hidden_states = self.model(

@@ -369,8 +371,10 @@ def load_model(self, target_model: nn.Module) -> None:
     def dummy_run(
         self,
         num_tokens: int,
+        attn_metadata: Optional[dict[str, Any]],
     ) -> None:
-        with set_forward_context(None, self.vllm_config,
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
                                  num_tokens=num_tokens):
             self.model(
                 self.input_ids[:num_tokens],
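This diff changes dummy_run to accept the attention metadata instead of hard-coding None into set_forward_context, so a drafter dummy pass runs under the same forward context as a real pass. A hedged sketch of that threading pattern, with a stand-in context manager rather than vLLM's real one:

from contextlib import contextmanager
from typing import Any, Optional


@contextmanager
def set_forward_context(attn_metadata: Optional[dict[str, Any]], num_tokens: int):
    # Stand-in for vLLM's context manager: expose the metadata for the
    # duration of one forward pass, then clear it.
    ctx = {"attn_metadata": attn_metadata, "num_tokens": num_tokens}
    try:
        yield ctx
    finally:
        ctx.clear()


def dummy_run(num_tokens: int, attn_metadata: Optional[dict[str, Any]]) -> None:
    # Forwarding real metadata (previously hard-coded None) lets the dummy
    # pass exercise the same attention code paths as a real forward pass.
    with set_forward_context(attn_metadata, num_tokens=num_tokens) as ctx:
        print(f"dummy forward: {ctx['num_tokens']} tokens, "
              f"metadata={ctx['attn_metadata']!r}")


dummy_run(8, None)                       # old behavior: no metadata
dummy_run(8, {"layers.0.attn": "meta"})  # new: per-layer metadata dict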
@@ -1860,7 +1860,7 @@ def maybe_randomize_inputs(self, input_ids: torch.Tensor):
         Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
         This is to help balance expert-selection
          - during profile_run
-          - during DP rank dummy run
+         - during DP rank dummy run
         """
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1

@@ -1982,7 +1982,7 @@ def _dummy_run(

         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
-            self.drafter.dummy_run(num_tokens)
+            self.drafter.dummy_run(num_tokens, attn_metadata)

         logit_indices = np.cumsum(num_scheduled_tokens) - 1
         return hidden_states, hidden_states[logit_indices]
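On the runner side, the call site simply forwards whatever attention metadata _dummy_run already built for the main model's dummy pass. A compressed sketch of that hand-off, with the drafter reduced to a stub (the metadata contents are illustrative):

from typing import Any, Optional


class EagleProposer:
    # Stub of the drafter; only the updated two-argument signature matters.
    def dummy_run(self, num_tokens: int,
                  attn_metadata: Optional[dict[str, Any]]) -> None:
        print(f"drafter dummy_run: {num_tokens} tokens, "
              f"has_metadata={attn_metadata is not None}")


# The runner built `attn_metadata` for its own dummy pass; the change is
# passing it through instead of letting the drafter fall back to None.
drafter = EagleProposer()
attn_metadata: Optional[dict[str, Any]] = {"layers.0.attn": "meta"}
assert isinstance(drafter, EagleProposer)
drafter.dummy_run(8, attn_metadata)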