Skip to content

[CB][do not merge] Support batch size 1 for decode, simplify warmup #312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/e2e/test_spyre_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ids=lambda val: f"TP({val})",
)
@pytest.mark.parametrize("backend", get_spyre_backend_list())
@pytest.mark.parametrize("max_num_seqs", [4],
@pytest.mark.parametrize("max_num_seqs", [1, 4],
ids=lambda val: f"max_num_seqs({val})")
def test_output(
model: str,
Expand Down
22 changes: 0 additions & 22 deletions vllm_spyre/v1/worker/spyre_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,11 +682,6 @@ def __init__(
super().__init__(vllm_config=vllm_config,
is_driver_worker=is_driver_worker)

# TODO: remove this limitation once we update the warm-up logic to
# support batch_size=1
assert vllm_config.scheduler_config.max_num_seqs >= 2, "Currently, " \
"continuous batching needs config to set batch_size >= 2"

self.block_size = SpyrePlatform.get_block_size()

# TODO: move to a KV cache manager
Expand Down Expand Up @@ -978,23 +973,6 @@ def _prepare_decode(
current_tkv_mask = torch.tensor([self.tkv] * len(input_tokens),
dtype=torch.int64)

# add pads for min decode batch size of 2 (Spyre compiler constraint)
if len(cached_request_data.req_ids) == 1:
padd_seq_indices = torch.zeros(1, dtype=torch.bool, device="cpu")
self.model.indices = torch.cat(
(self.model.indices, padd_seq_indices), -1)
assert self.model.indices.size(dim=0) == 2

input_tokens = torch.cat(2 * [input_tokens])
position_ids = torch.cat(2 * [position_ids])
current_tkv_mask = torch.cat(2 * [current_tkv_mask])
left_padded_prompt_mask = torch.cat(2 * [left_padded_prompt_mask])
block_table = torch.cat(2 * [block_table])
slot_mapping = torch.cat(2 * [slot_mapping])

# assert min batch size 2 for decodes (Spyre compiler constraint)
assert len(input_tokens) >= 2

model_inputs = ModelForwardInputs(
input_tokens=input_tokens,
input_positions=position_ids,
Expand Down
25 changes: 25 additions & 0 deletions vllm_spyre/v1/worker/spyre_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import vllm_spyre.perf_metrics as perf_metrics
from vllm_spyre.model_executor.model_loader import spyre_setup
from vllm_spyre.platform import SpyrePlatform
from vllm_spyre.v1.worker.spyre_input_batch import InputBatch
from vllm_spyre.v1.worker.spyre_model_runner import (
ContinuousBatchingSpyreModelRunner, StaticBatchingSpyreModelRunner)

Expand Down Expand Up @@ -321,6 +322,17 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
prompt_len = 42
num_decode_tokens = 2

# Fix for batch size 1: set input batch to fit 2 requests for warmup
if model_runner.vllm_config.scheduler_config.max_num_seqs == 1:
model_runner.input_batch = InputBatch(
max_num_reqs=2,
max_model_len=model_runner.vllm_config.model_config.
max_model_len,
device=model_runner.device,
pin_memory=model_runner.pin_memory,
vocab_size=model_runner.vllm_config.model_config.
get_vocab_size(),
)
# Sample from the valid token ids
warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (batch_size + 1, prompt_len))]
Expand Down Expand Up @@ -368,6 +380,19 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
self.execute_model(scheduler_output)
self._cleanup_model_runner(request=[add_dummy_request])

# Fix for batch size 1: reset input batch to fit max_num_seqs requests
if model_runner.vllm_config.scheduler_config.max_num_seqs == 1:
model_runner.input_batch = InputBatch(
max_num_reqs=model_runner.vllm_config.scheduler_config.
max_num_seqs,
max_model_len=model_runner.vllm_config.model_config.
max_model_len,
device=model_runner.device,
pin_memory=model_runner.pin_memory,
vocab_size=model_runner.vllm_config.model_config.
get_vocab_size(),
)

model_runner.finish_warmup()

warmup_end_t = time.time()
Expand Down
Loading