Skip to content

[CB] Support pseudo batch size 1 for decode, adjust warmup #287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion tests/e2e/test_spyre_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ids=lambda val: f"TP({val})",
)
@pytest.mark.parametrize("backend", get_spyre_backend_list())
@pytest.mark.parametrize("max_num_seqs", [4],
@pytest.mark.parametrize("max_num_seqs", [1, 4],
ids=lambda val: f"max_num_seqs({val})")
def test_output(
model: str,
Expand Down
3 changes: 2 additions & 1 deletion vllm_spyre/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# set env vars for torch_sendnn to consume
os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
vllm_config.model_config.max_model_len)
# min decode batch size is 2 due to symbolic shape constraint in torch
os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(
vllm_config.scheduler_config.max_num_seqs)
max(vllm_config.scheduler_config.max_num_seqs, 2))

@classmethod
def use_all_gather(cls) -> bool:
Expand Down
5 changes: 0 additions & 5 deletions vllm_spyre/v1/worker/spyre_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,11 +682,6 @@ def __init__(
super().__init__(vllm_config=vllm_config,
is_driver_worker=is_driver_worker)

# TODO: remove this limitation once we update the warm-up logic to
# support batch_size=1
assert vllm_config.scheduler_config.max_num_seqs >= 2, "Currently, " \
"continuous batching needs config to set batch_size >= 2"

self.block_size = SpyrePlatform.get_block_size()

# TODO: move to a KV cache manager
Expand Down
26 changes: 26 additions & 0 deletions vllm_spyre/v1/worker/spyre_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import vllm_spyre.perf_metrics as perf_metrics
from vllm_spyre.model_executor.model_loader import spyre_setup
from vllm_spyre.platform import SpyrePlatform
from vllm_spyre.v1.worker.spyre_input_batch import InputBatch
from vllm_spyre.v1.worker.spyre_model_runner import (
ContinuousBatchingSpyreModelRunner, StaticBatchingSpyreModelRunner)

Expand Down Expand Up @@ -321,6 +322,18 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
prompt_len = 42
num_decode_tokens = 2

# Fix for batch size 1: set input batch to fit 2 requests for warmup
if model_runner.vllm_config.scheduler_config.max_num_seqs == 1:
model_runner.input_batch = InputBatch(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, could the InputBatch construct itself with:

self.max_num_reqs = min(max_num_reqs, 2)

since we know that it'll always need at least 2, and then we avoid reconstructing it in the worker here? That way we have a much smaller diff to back out once we can lift this bs>=2 restriction

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if I follow here. it has to be >=2 for the warmup. with the min(1,2) we would still fail?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would that work if you directly set model_runner.input_batch.max_num_reqs = 2, instead of instantiating a new InputBatch?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, because InputBatch initialization gets model_runner.input_batch.max_num_reqs..

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_num_seqs occurs 17 times in the init of the InputBatch. It is not a single attribute, but used to construct several attributes. So re-initializing is simpler...

max_num_reqs=2,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it necessary to set the full input batch object instead of just setting model_runner.input_batch.max_num_reqs=2?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, this is not possible. please have a look at the comment here: #287 (comment)

max_model_len=model_runner.vllm_config.model_config.
max_model_len,
device=model_runner.device,
pin_memory=model_runner.pin_memory,
vocab_size=model_runner.vllm_config.model_config.
get_vocab_size(),
)

# Sample from the valid token ids
warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (batch_size + 1, prompt_len))]
Expand Down Expand Up @@ -368,6 +381,19 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
self.execute_model(scheduler_output)
self._cleanup_model_runner(request=[add_dummy_request])

# Fix for batch size 1: reset input batch to fit max_num_seqs requests
if model_runner.vllm_config.scheduler_config.max_num_seqs == 1:
model_runner.input_batch = InputBatch(
max_num_reqs=model_runner.vllm_config.scheduler_config.
max_num_seqs,
max_model_len=model_runner.vllm_config.model_config.
max_model_len,
device=model_runner.device,
pin_memory=model_runner.pin_memory,
vocab_size=model_runner.vllm_config.model_config.
get_vocab_size(),
)

model_runner.finish_warmup()

warmup_end_t = time.time()
Expand Down
Loading