[CB] Support pseudo batch size 1 for decode, adjust warmup #287
base: main
Changes from all commits
@@ -28,6 +28,7 @@
 import vllm_spyre.perf_metrics as perf_metrics
 from vllm_spyre.model_executor.model_loader import spyre_setup
 from vllm_spyre.platform import SpyrePlatform
+from vllm_spyre.v1.worker.spyre_input_batch import InputBatch
 from vllm_spyre.v1.worker.spyre_model_runner import (
     ContinuousBatchingSpyreModelRunner, StaticBatchingSpyreModelRunner)

@@ -321,6 +322,18 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
         prompt_len = 42
         num_decode_tokens = 2

+        # Fix for batch size 1: set input batch to fit 2 requests for warmup
+        if model_runner.vllm_config.scheduler_config.max_num_seqs == 1:
+            model_runner.input_batch = InputBatch(
+                max_num_reqs=2,
+                max_model_len=model_runner.vllm_config.model_config.
+                max_model_len,
+                device=model_runner.device,
+                pin_memory=model_runner.pin_memory,
+                vocab_size=model_runner.vllm_config.model_config.
+                get_vocab_size(),
+            )
+
         # Sample from the valid token ids
         warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
             0, len(valid_token_ids_tensor), (batch_size + 1, prompt_len))]

Inline comment on max_num_reqs=2: Why is it necessary to set the full input batch object instead of just setting …

Reply: no, this is not possible. please have a look at the comment here: #287 (comment)
@@ -368,6 +381,19 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
         self.execute_model(scheduler_output)
         self._cleanup_model_runner(request=[add_dummy_request])

+        # Fix for batch size 1: reset input batch to fit max_num_seqs requests
+        if model_runner.vllm_config.scheduler_config.max_num_seqs == 1:
+            model_runner.input_batch = InputBatch(
+                max_num_reqs=model_runner.vllm_config.scheduler_config.
+                max_num_seqs,
+                max_model_len=model_runner.vllm_config.model_config.
+                max_model_len,
+                device=model_runner.device,
+                pin_memory=model_runner.pin_memory,
+                vocab_size=model_runner.vllm_config.model_config.
+                get_vocab_size(),
+            )
+
         model_runner.finish_warmup()

         warmup_end_t = time.time()
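Read together, the two hunks above amount to a temporary swap around the warmup. A condensed sketch of that flow is below; run_warmup_requests is a placeholder for the existing prefill/decode warmup steps, not a function from this PR.

```python
# Condensed view of the flow the two hunks implement; run_warmup_requests is
# a placeholder for the existing prefill/decode warmup logic, not real code.
from vllm_spyre.v1.worker.spyre_input_batch import InputBatch


def warmup_with_single_seq(model_runner, run_warmup_requests):
    cfg = model_runner.vllm_config
    if cfg.scheduler_config.max_num_seqs == 1:
        # Temporarily give the input batch room for the 2 warmup requests.
        model_runner.input_batch = InputBatch(
            max_num_reqs=2,
            max_model_len=cfg.model_config.max_model_len,
            device=model_runner.device,
            pin_memory=model_runner.pin_memory,
            vocab_size=cfg.model_config.get_vocab_size(),
        )

    run_warmup_requests()  # prefill + decode with the dummy requests

    if cfg.scheduler_config.max_num_seqs == 1:
        # Restore the configured capacity once warmup is done.
        model_runner.input_batch = InputBatch(
            max_num_reqs=cfg.scheduler_config.max_num_seqs,
            max_model_len=cfg.model_config.max_model_len,
            device=model_runner.device,
            pin_memory=model_runner.pin_memory,
            vocab_size=cfg.model_config.get_vocab_size(),
        )

    model_runner.finish_warmup()
```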
Review discussion:

Comment: Alternatively, could the InputBatch construct itself with: … since we know that it'll always need at least 2, and then we avoid reconstructing it in the worker here? That way we have a much smaller diff to back out once we can lift this bs>=2 restriction.

Reply: not sure if I follow here. it has to be >=2 for the warmup. with the min(1,2) we would still fail?
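For reference, a minimal sketch of that alternative, assuming the clamp would live inside InputBatch.__init__ itself. The max(2, ...) form and the trimmed signature are assumptions about the suggestion, not code from this PR.

```python
# Hypothetical sketch only: let InputBatch clamp its own capacity so the
# worker never has to rebuild it for warmup. The max(2, ...) clamp and the
# trimmed __init__ signature are assumptions, not code from this PR.
import torch


class InputBatch:

    def __init__(self, max_num_reqs: int, max_model_len: int,
                 device: torch.device, pin_memory: bool, vocab_size: int):
        # The dynamic warmup always schedules two dummy requests, so reserve
        # room for at least two even when the scheduler allows only one.
        self.max_num_reqs = max(2, max_num_reqs)
        self.max_model_len = max_model_len
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size
        # ... per-request buffers would then be sized from self.max_num_reqs
```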
Comment: would that work if you directly set model_runner.input_batch.max_num_reqs = 2, instead of instantiating a new InputBatch?

Reply: no, because InputBatch initialization gets model_runner.input_batch.max_num_reqs.
Reply: max_num_seqs occurs 17 times in the init of the InputBatch. It is not a single attribute, but is used to construct several attributes, so re-initializing is simpler.
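To illustrate that point, the capacity is consumed at construction time to size per-request buffers, so assigning a new max_num_reqs afterwards would not resize anything. The class and field names below are made up for illustration; the real spyre_input_batch.InputBatch has many more fields.

```python
# Illustration only: capacity is consumed in __init__ to size buffers, so
# mutating the attribute later does not grow them. Field names are made up.
import torch


class TinyInputBatch:

    def __init__(self, max_num_reqs: int, max_model_len: int,
                 device: torch.device):
        self.max_num_reqs = max_num_reqs
        # Buffers sized from max_num_reqs at construction time.
        self.token_ids = torch.zeros(max_num_reqs, max_model_len,
                                     dtype=torch.long, device=device)
        self.num_tokens = torch.zeros(max_num_reqs, dtype=torch.int32)
        self.req_ids = [None] * max_num_reqs


batch = TinyInputBatch(max_num_reqs=1, max_model_len=64,
                       device=torch.device("cpu"))
batch.max_num_reqs = 2  # the buffers above still only hold one request
# Rebuilding is the only way to actually grow them:
batch = TinyInputBatch(max_num_reqs=2, max_model_len=64,
                       device=torch.device("cpu"))
```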