
Commit ac3cd6e

[core] add bucket padding to tpu_model_runner (#14995)
Signed-off-by: Chenyaaang <llccyy1212@gmail.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
1 parent 082ab86 commit ac3cd6e

3 files changed: +63 additions, -19 deletions

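In short, this commit replaces the old power-of-two token padding with a bucketed scheme: bucket sizes double from 16 up to the configured gap, then grow linearly by the gap until they cover max_num_tokens, and each batch is padded up to the nearest bucket. The snippet below is a minimal standalone sketch of that behaviour, mirroring the _get_paddings and _get_padded_token_len helpers added in the diff further down (the names and logic come from this commit; the snippet itself is only illustrative):

import bisect


def get_paddings(min_token_size: int, max_token_size: int,
                 padding_gap: int) -> list[int]:
    # Illustrative copy of the _get_paddings helper added in this commit.
    paddings, num = [], min_token_size
    while num <= padding_gap:      # double until the gap is reached ...
        paddings.append(num)
        num *= 2
    num //= 2
    while num < max_token_size:    # ... then grow linearly by the gap
        num += padding_gap
        paddings.append(num)
    return paddings


def get_padded_token_len(paddings: list[int], x: int) -> int:
    # Smallest bucket that fits x (binary search, like the new helper).
    return paddings[bisect.bisect_left(paddings, x)]


buckets = get_paddings(min_token_size=16, max_token_size=512, padding_gap=64)
print(buckets)                             # [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
print(get_padded_token_len(buckets, 20))   # 32
print(get_padded_token_len(buckets, 200))  # 256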

tests/v1/tpu/worker/test_tpu_model_runner.py

Lines changed: 21 additions & 1 deletion
@@ -8,7 +8,9 @@
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.worker.tpu_model_runner import TPUModelRunner
+from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
+                                             _get_padded_token_len,
+                                             _get_paddings)
 
 # Mock torch_xla module since it may not be available in the test environments
 torch_xla_patcher = mock.patch.dict(
@@ -305,3 +307,21 @@ def test_update_states_request_unscheduled(model_runner):
 
     assert _is_req_added(model_runner, req_ids[1])
     assert not _is_req_scheduled(model_runner, req_ids[1])
+
+
+def test_get_paddings():
+    min_token_size, max_token_size, padding_gap = 16, 512, 64
+    expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
+    actual_paddings = _get_paddings(min_token_size, max_token_size,
+                                    padding_gap)
+    assert actual_paddings == expected_paddings
+
+
+def test_get_padded_token_len():
+    min_token_size, max_token_size, padding_gap = 16, 512, 64
+    paddings = _get_paddings(min_token_size, max_token_size, padding_gap)
+    assert _get_padded_token_len(paddings, 1) == 16
+    assert _get_padded_token_len(paddings, 16) == 16
+    assert _get_padded_token_len(paddings, 20) == 32
+    assert _get_padded_token_len(paddings, 300) == 320
+    assert _get_padded_token_len(paddings, 512) == 512

vllm/envs.py

Lines changed: 7 additions & 0 deletions
@@ -97,6 +97,7 @@
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
+    VLLM_TPU_BUCKET_PADDING_GAP: int = 64
 
 
 def get_default_cache_root():
@@ -627,6 +628,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
     lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
     if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
+
+    # Gap between padding buckets for the forward pass. So, with a gap of 8,
+    # we will run the forward pass with [16, 24, 32, ...].
+    "VLLM_TPU_BUCKET_PADDING_GAP":
+    lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
+    if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 64,
 }
 
 # end-env-vars-definition
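The new variable only changes the spacing of the buckets after the initial doubling phase. As a rough usage sketch (the variable name comes from this diff; the chosen value of 32 and the point at which it is set are illustrative assumptions, the runner reads it when it is constructed):

import os

# Illustrative override: a gap of 32 instead of the default 64. Set it in the
# environment before the TPU model runner is constructed (it reads
# envs.VLLM_TPU_BUCKET_PADDING_GAP in __init__, as shown in the diff below).
os.environ["VLLM_TPU_BUCKET_PADDING_GAP"] = "32"

# Effect on the padding buckets built by _get_paddings (minimum size 16):
#   gap = 8  -> [16, 24, 32, 40, ...]          (the example in the comment above)
#   gap = 32 -> [16, 32, 64, 96, 128, ...]
#   gap = 64 -> [16, 32, 64, 128, 192, ...]    (the default)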

vllm/v1/worker/tpu_model_runner.py

Lines changed: 35 additions & 18 deletions
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import bisect
 import time
 from typing import TYPE_CHECKING, Optional, cast
 from unittest.mock import patch
@@ -170,6 +171,10 @@ def __init__(
         # Range tensor with values [0 .. self.max_num_tokens - 1].
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
+        self.num_tokens_paddings = _get_paddings(
+            min_token_size=16,
+            max_token_size=self.max_num_tokens,
+            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         """Update the cached states and the persistent batch with the scheduler
@@ -422,7 +427,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
 
         # Do the padding and copy the tensors to the TPU.
         padded_total_num_scheduled_tokens = _get_padded_token_len(
-            total_num_scheduled_tokens)
+            self.num_tokens_paddings, total_num_scheduled_tokens)
         # Zero out to avoid spurious values from prev iteration (last cp chunk)
         self.input_ids_cpu[
             total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
@@ -573,7 +578,6 @@ def execute_model(
 
         # Prepare inputs
         attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
-
         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
@@ -764,26 +768,21 @@ def capture_model(self) -> None:
         logger.info("Compiling the model with different input shapes.")
 
         start = time.perf_counter()
-        num_tokens = 16
-        while True:
+        for num_tokens in self.num_tokens_paddings:
            logger.info("  -- num_tokens: %d", num_tokens)
            self._dummy_run(self.kv_caches, num_tokens)
            xm.mark_step()
-            if num_tokens >= self.max_num_tokens:
-                break
-            num_tokens *= 2
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
 
         logger.info("Compiling sampling with different input shapes.")
         start = time.perf_counter()
-        num_tokens = 16
         hsize = self.model_config.get_hidden_size()
         device = self.device
         # Compile sampling step for different model+sampler outputs in bucketed
         # n_tokens x max_num_reqs. Graph is really small so this is fine.
-        while True:
+        for num_tokens in self.num_tokens_paddings:
            num_reqs_to_sample = MIN_NUM_SEQS
            dummy_hidden = torch.randn((num_tokens, hsize),
                                       device=device,
@@ -805,9 +804,6 @@ def capture_model(self) -> None:
                if num_reqs_to_sample >= self.max_num_reqs:
                    break
                num_reqs_to_sample *= 2
-            if num_tokens >= self.max_num_tokens:
-                break
-            num_tokens *= 2
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
@@ -939,12 +935,33 @@ def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
 
 
-def _get_padded_token_len(x: int) -> int:
-    if x <= 16:
-        return 16
-    return 1 << (x - 1).bit_length()
-
-
 def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
     res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
     return min(res, upper_limit)
+
+
+def _get_paddings(min_token_size: int, max_token_size: int,
+                  padding_gap: int) -> list[int]:
+    """Generate a list of padding sizes, starting from min_token_size and
+    ending with a value that covers max_token_size:
+    first double the size until it exceeds padding_gap,
+    then increase it linearly by padding_gap.
+    """
+    paddings = []
+    num = min_token_size
+    while num <= padding_gap:
+        paddings.append(num)
+        num *= 2
+    num //= 2
+    while num < max_token_size:
+        num += padding_gap
+        paddings.append(num)
+    return paddings
+
+
+def _get_padded_token_len(paddings: list[int], x: int) -> int:
+    """Return the first element in paddings greater than or equal to x.
+    """
+    index = bisect.bisect_left(paddings, x)
+    assert index < len(paddings)
+    return paddings[index]
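For a concrete sense of what the capture_model change buys, the small standalone comparison below contrasts the shapes compiled before (doubling from 16 up to max_num_tokens) with the new bucket list, using max_num_tokens = 512 and the default gap of 64 as in the tests above; the numbers come from the diff and the tests, the snippet itself is only illustrative:

# Shapes compiled by capture_model before this change: double from 16 upwards.
old_shapes = []
num_tokens = 16
while True:
    old_shapes.append(num_tokens)
    if num_tokens >= 512:
        break
    num_tokens *= 2
print(old_shapes)  # [16, 32, 64, 128, 256, 512]

# Shapes compiled after this change (the list asserted in test_get_paddings).
new_shapes = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]

# Padding chosen for a batch of 300 scheduled tokens:
print(1 << (300 - 1).bit_length())              # 512 -- old: next power of two
print(next(p for p in new_shapes if p >= 300))  # 320 -- new: next bucket

# Past the doubling phase, over-padding is now bounded by the gap (here 64 tokens)
# instead of growing with the batch size, at the cost of compiling a few more graphs.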
