@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import bisect
 import time
 from typing import TYPE_CHECKING, Optional, cast
 from unittest.mock import patch
@@ -170,6 +171,10 @@ def __init__(
         # Range tensor with values [0 .. self.max_num_tokens - 1].
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
+        self.num_tokens_paddings = _get_paddings(
+            min_token_size=16,
+            max_token_size=self.max_num_tokens,
+            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         """Update the cached states and the persistent batch with the scheduler
@@ -422,7 +427,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
 
         # Do the padding and copy the tensors to the TPU.
         padded_total_num_scheduled_tokens = _get_padded_token_len(
-            total_num_scheduled_tokens)
+            self.num_tokens_paddings, total_num_scheduled_tokens)
         # Zero out to avoid spurious values from prev iteration (last cp chunk)
         self.input_ids_cpu[
             total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
@@ -573,7 +578,6 @@ def execute_model(
 
         # Prepare inputs
         attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
-
         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
@@ -764,26 +768,21 @@ def capture_model(self) -> None:
         logger.info("Compiling the model with different input shapes.")
 
         start = time.perf_counter()
-        num_tokens = 16
-        while True:
+        for num_tokens in self.num_tokens_paddings:
             logger.info(" -- num_tokens: %d", num_tokens)
             self._dummy_run(self.kv_caches, num_tokens)
             xm.mark_step()
-            if num_tokens >= self.max_num_tokens:
-                break
-            num_tokens *= 2
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
 
         logger.info("Compiling sampling with different input shapes.")
         start = time.perf_counter()
-        num_tokens = 16
         hsize = self.model_config.get_hidden_size()
         device = self.device
         # Compile sampling step for different model+sampler outputs in bucketed
         # n_tokens x max_num_reqs. Graph is really small so this is fine.
-        while True:
+        for num_tokens in self.num_tokens_paddings:
             num_reqs_to_sample = MIN_NUM_SEQS
             dummy_hidden = torch.randn((num_tokens, hsize),
                                        device=device,
@@ -805,9 +804,6 @@ def capture_model(self) -> None:
                 if num_reqs_to_sample >= self.max_num_reqs:
                     break
                 num_reqs_to_sample *= 2
-            if num_tokens >= self.max_num_tokens:
-                break
-            num_tokens *= 2
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
@@ -939,12 +935,33 @@ def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
 
 
-def _get_padded_token_len(x: int) -> int:
-    if x <= 16:
-        return 16
-    return 1 << (x - 1).bit_length()
-
-
 def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
     res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
     return min(res, upper_limit)
+
+
+def _get_paddings(min_token_size: int, max_token_size: int,
+                  padding_gap: int) -> list[int]:
+    """Generate a list of padding sizes, starting from min_token_size and
+    ending with a value that covers max_token_size.
+    The size is doubled until it exceeds padding_gap,
+    then it grows linearly by padding_gap.
+    """
+    paddings = []
+    num = min_token_size
+    while num <= padding_gap:
+        paddings.append(num)
+        num *= 2
+    num //= 2
+    while num < max_token_size:
+        num += padding_gap
+        paddings.append(num)
+    return paddings
+
+
+def _get_padded_token_len(paddings: list[int], x: int) -> int:
+    """Return the first element in paddings greater than or equal to x.
+    """
+    index = bisect.bisect_left(paddings, x)
+    assert index < len(paddings)
+    return paddings[index]
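
For illustration, a minimal usage sketch of the two helpers added above. The import path and the concrete sizes (padding_gap=64, max_token_size=512) are assumptions chosen for the example; in the runner the gap comes from VLLM_TPU_BUCKET_PADDING_GAP and the maximum from self.max_num_tokens.

# Illustrative sketch only; the import path and the sizes below are assumptions.
from vllm.v1.worker.tpu_model_runner import _get_paddings, _get_padded_token_len

# Buckets double from 16 up to padding_gap, then grow linearly by padding_gap
# until they cover max_token_size.
paddings = _get_paddings(min_token_size=16, max_token_size=512, padding_gap=64)
print(paddings)  # [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]

# _prepare_inputs pads the scheduled token count to the smallest bucket that
# fits, so at most len(paddings) input shapes are ever compiled.
print(_get_padded_token_len(paddings, 100))  # 128
print(_get_padded_token_len(paddings, 512))  # 512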