
Commit a0dd7dc

[TPU][V1] Fix Sampler recompilation (#15309)
Signed-off-by: NickLucche <nlucches@redhat.com>
1 parent e977c11 commit a0dd7dc

File tree

vllm/v1/sample/tpu/metadata.py
vllm/v1/worker/tpu_model_runner.py

2 files changed: +84 −127 lines changed

vllm/v1/sample/tpu/metadata.py

Lines changed: 71 additions & 104 deletions
@@ -5,7 +5,18 @@
 import torch
 import torch_xla.core.xla_model as xm
 
-from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+DEFAULT_SAMPLING_PARAMS = dict(
+    temperature=-1.0,
+    min_p=0.0,
+    # strictly disabled for now
+    # top_k=-1,
+    # top_p=0.0,
+    # frequency_penalties=0.0,
+    # presence_penalties=0.0,
+    # repetition_penalties=0.0,
+)
 
 
 @dataclass
@@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
     top_k: torch.Tensor = None
     top_p: torch.Tensor = None
 
-    # XLA-unfriendly control flow in Sampler
-    all_greedy: bool = False
-    all_random: bool = False
     # Greedy sampling flag for compiling single xla graph.
-    do_argmax: torch.Tensor = None
-
-    # speculation not supported
-    spec_token_ids = None
+    all_greedy: torch.Tensor = None
 
     # Generator not supported by xla
     generators: dict[int,
@@ -54,106 +59,68 @@ class TPUSupportedSamplingMetadata:
     bad_words_token_ids = None
     indices_do_sample: torch.Tensor = None
 
-    def __post_init__(self):
-        temp = self.temperature
-        if self.indices_do_sample is None:
-            self.indices_do_sample = torch.zeros(temp.shape[0],
-                                                 device=temp.device,
-                                                 dtype=torch.int32)
-        if self.do_argmax is None:
-            self.do_argmax = torch.tensor(0,
-                                          dtype=torch.bool,
-                                          device=temp.device)
-
     @classmethod
-    def from_sampling_metadata(
-            cls, metadata: SamplingMetadata,
-            padded_do_sample_indices: torch.Tensor, num_do_sample: int,
-            device: torch.device) -> "TPUSupportedSamplingMetadata":
+    def from_input_batch(
+            cls, input_batch: InputBatch,
+            indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
         """
-        Create an XLA-frienly SamplingMetadata structure. Do so by first
-        instantiating an object with fixed-sized tensors and then writing the
-        values in input `metadata`. Do that only for non-None values so that
-        recompilation is not triggered for optional values (None/torch.Tensor).
-
-        In order to handle different sizes for the params that range from 1 up
-        to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
-        Same thing for `padded_do_sample_indices`, which contains the indices
-        to be fed to the Sampler, padded to the closest pre-compiled shape.
-
-        Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
-        do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
+        Copy sampling tensors slices from `input_batch` to on device tensors.
+
+        `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
+        slices dynamic shapes on device tensors. This impl moves the dynamic
+        ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
+        also reuses the on-device persistent tensors managed in `input_batch`
+        to reduce waste.
+
+        `indices_do_sample` contains the indices to be fed to the Sampler,
+        normally one per request, here padded to the closest pre-compiled shape
+        We expect sampling params tensors to be padded to the same fixed shape.
+
+        Eg. 3 requests, tensors padded to 4
+        temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
+        sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
         """
-        metadata = cls._validate_sampling_metadata(metadata)
-        # NOTE we have to initialize default tensor-based params first and
-        # skip None values altogether to produce the same xla graph.
-        num_samples = len(padded_do_sample_indices)
-        do_argmax = torch.tensor(metadata.all_greedy,
-                                 dtype=torch.bool,
-                                 device=device)
-        new_metadata = cls.get_default_sampling_params(num_samples, device,
-                                                       indices_do_sample=\
-                                                       padded_do_sample_indices,
-                                                       do_argmax=do_argmax
-                                                       )
-        supported_params = \
-            TPUSupportedSamplingMetadata._get_default_params_values()
-        # Copy input non-None values into `new_metadata` fixed-sized tensors.
-        for p_name in supported_params:
-            old_val = getattr(metadata, p_name)
-            new_val = getattr(new_metadata, p_name)
-            if isinstance(old_val, torch.Tensor):
-                new_val[:num_do_sample] = old_val
-                setattr(new_metadata, p_name, new_val)
+        num_reqs = input_batch.num_reqs
+        padded_num_reqs = len(indices_do_sample)
+
+        def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
+                       fill_val) -> torch.Tensor:
+            # Copy slice from CPU to corresponding TPU pre-allocated tensor.
+            # Pad value is the default one.
+            cpu_tensor[num_reqs:padded_num_reqs] = fill_val
+            tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
+
+        # NOTE NickLucche The sync CPU-TPU graph we produce here must be
+        # consistent. We can't have flags to skip copies or we'll end up
+        # recompiling.
+        copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
+                   DEFAULT_SAMPLING_PARAMS["temperature"])
+        # TODO Temporarily disabled until sampling options are enabled
+        # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
+        # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
+        copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
+                   DEFAULT_SAMPLING_PARAMS["min_p"])
+
+        # copy_slice(input_batch.frequency_penalties_cpu_tensor,
+        #            input_batch.frequency_penalties)
+        # copy_slice(input_batch.presence_penalties_cpu_tensor,
+        #            input_batch.presence_penalties)
+        # copy_slice(input_batch.repetition_penalties_cpu_tensor,
+        #            input_batch.repetition_penalties)
 
         xm.mark_step()
         xm.wait_device_ops()
-        return new_metadata
 
-    @classmethod
-    def get_default_sampling_params(
-            cls,
-            num_samples: int,
-            device: torch.device,
-            indices_do_sample=None,
-            do_argmax=None) -> "TPUSupportedSamplingMetadata":
-        # As sampling happens on a single traced graph, options
-        # are "disabled" by having them evaluate to an Identity op.
-        # Note that initialization is dependent on num_samples.
-        sampling_metadata_disable_value = \
-            TPUSupportedSamplingMetadata._get_default_params_values()
-        init_kwargs = dict()
-        for p_name, (default_val,
-                     dtype) in sampling_metadata_disable_value.items():
-            default_tensor = torch.full((num_samples, ),
-                                        default_val,
-                                        dtype=dtype,
-                                        device=device)
-            init_kwargs[p_name] = default_tensor
-
-        return cls(**init_kwargs,
-                   indices_do_sample=indices_do_sample,
-                   do_argmax=do_argmax)
-
-    @staticmethod
-    def _validate_sampling_metadata(
-            sampling_metadata: SamplingMetadata) -> SamplingMetadata:
-        if sampling_metadata.all_greedy:
-            # Set to None since #13587. Make sure default isn't overruled.
-            assert sampling_metadata.temperature is None
-        return sampling_metadata
-
-    @staticmethod
-    def _get_default_params_values():
-        return dict(
-            # Since #13587 greedy sampling requires branching off which leads
-            # to separate graphs. We set temp to noop and handle argmax here.
-            temperature=(1.0, torch.float32),
-            min_p=(0.0, torch.float32),
-            # strictly disabled for now
-            # top_k=(-1, torch.int32),
-            # top_p=(0.0, torch.float32),
-            # frequency_penalties=(0.0, torch.float32),
-            # presence_penalties=(0.0, torch.float32),
-            # repetition_penalties=(0.0, torch.float32),
-        )
+        # Slice persistent device tensors to a fixed pre-compiled padded shape.
+        return cls(
+            temperature=input_batch.temperature[:padded_num_reqs],
+            # Scalar tensor for xla-friendly tracing.
+            all_greedy=torch.tensor(input_batch.all_greedy,
+                                    dtype=torch.bool,
+                                    device=input_batch.device),
+            # TODO enable more and avoid returning None values
+            top_p=None,  # input_batch.top_p[:padded_num_reqs],
+            top_k=None,  # input_batch.top_k[:padded_num_reqs],
+            min_p=input_batch.min_p[:padded_num_reqs],
+            generators=input_batch.generators,
+            indices_do_sample=indices_do_sample)
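
The key idea in the hunk above is that every CPU-to-TPU copy and every slice now has the same pre-compiled shape regardless of how many requests are live, so XLA never sees a new shape and never recompiles. Below is a minimal, CPU-only sketch of that padding scheme; the buffer names and the pad_to_precompiled helper are illustrative stand-ins, not the vLLM API.

import torch

MAX_NUM_REQS = 16  # stand-in for the runner's max_num_reqs

def pad_to_precompiled(n: int, minimum: int = 8) -> int:
    # Hypothetical helper: round the live request count up to the nearest
    # pre-compiled padded size (powers of two, capped at MAX_NUM_REQS).
    size = minimum
    while size < n:
        size *= 2
    return min(size, MAX_NUM_REQS)

# Persistent buffers allocated once at the maximum size; in vLLM the second
# one lives on the XLA device, here both are CPU tensors for illustration.
temperature_cpu = torch.zeros(MAX_NUM_REQS, dtype=torch.float32)
temperature_dev = torch.zeros(MAX_NUM_REQS, dtype=torch.float32)

num_reqs = 3
padded_num_reqs = pad_to_precompiled(num_reqs)  # -> 8

# Live values go into the CPU buffer, the padding region gets the default
# value (-1.0 for temperature, per DEFAULT_SAMPLING_PARAMS above), and only a
# fixed-size slice is copied over, so the transfer shape never changes.
temperature_cpu[:num_reqs] = torch.tensor([0.7, 0.2, 0.9])
temperature_cpu[num_reqs:padded_num_reqs] = -1.0
temperature_dev[:padded_num_reqs] = temperature_cpu[:padded_num_reqs]

print(temperature_dev[:padded_num_reqs])
# tensor([ 0.7,  0.2,  0.9, -1.0, -1.0, -1.0, -1.0, -1.0])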

vllm/v1/worker/tpu_model_runner.py

Lines changed: 13 additions & 23 deletions
@@ -279,9 +279,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
                 req_data.num_computed_tokens)
             self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                     req_index)
-        # Check if the batch has changed. If not, we can skip copying the
-        # sampling metadata from CPU to GPU.
-        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
@@ -300,9 +297,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)
 
-        # TODO This slices tensors to copy to device, triggering recompilation.
-        if batch_changed:
-            self.input_batch.refresh_sampling_metadata()
         return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
 
     def get_model(self) -> nn.Module:
@@ -597,14 +591,12 @@ def execute_model(
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids
             inputs_embeds = None
-        sampling_metadata = self.input_batch.sampling_metadata
         num_reqs = self.input_batch.num_reqs
-        # NOTE (NickLucche) here we sync with TPU: if there's any shape
-        # mismatch in pre-processing, it will trigger a small recompilation
-        # of the code thus far. Forward graph remains untouched.
+        # NOTE (NickLucche) here we sync with TPU: sampling params tensors
+        # are copied to device in chunks of pre-compiled padded shape to
+        # avoid recompilations.
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
-            from_sampling_metadata(sampling_metadata, logits_indices,
-                                   num_reqs, self.device)
+            from_input_batch(self.input_batch, logits_indices)
         # Run the decoder
         with set_forward_context(attn_metadata, self.vllm_config):
             hidden_states = self.model(
@@ -797,21 +789,19 @@ def capture_model(self) -> None:
                                    device=device,
                                    dtype=torch.bfloat16)
         while True:
-            # Default metadata is an all_greedy setup. But since the
-            # `do_argmax` flag is a tensor, we still compile the full graph
-            meta = self.input_batch.sampling_metadata
             indices = torch.zeros(
                 num_reqs_to_sample,
                 dtype=torch.int32,
                 device=device,
             )
+            xm.mark_step()
             sampling_meta = TPUSupportedSamplingMetadata.\
-                from_sampling_metadata(meta, indices,
-                                       num_reqs_to_sample, device)
+                from_input_batch(self.input_batch, indices)
             logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
                         num_reqs_to_sample)
-            self.model.sample_from_hidden(dummy_hidden, sampling_meta)
-            xm.mark_step()
+            out = self.model.sample_from_hidden(dummy_hidden,
+                                                sampling_meta)
+            out = out.cpu()
             if num_reqs_to_sample >= self.max_num_reqs:
                 break
             num_reqs_to_sample *= 2
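
The capture loop above warms up one sampling graph per padded batch size, doubling the request count until it reaches max_num_reqs, so any padded_num_reqs seen at runtime maps onto a shape that has already been traced. A rough sketch of that shape schedule, reduced to its control flow (the real loop also builds the metadata, runs sample_from_hidden and syncs via xm.mark_step):

def padded_sizes(start: int, max_num_reqs: int):
    # Yield the padded request counts the capture loop traces:
    # start, 2*start, 4*start, ... until max_num_reqs is reached.
    n = start
    while True:
        yield n
        if n >= max_num_reqs:
            break
        n *= 2

# e.g. with an assumed starting size of 8 and max_num_reqs=16 -> [8, 16]
print(list(padded_sizes(8, 16)))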
@@ -910,6 +900,7 @@ def forward(
 
         return hidden_states
 
+    # @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def sample_from_hidden(
         self,
         hidden_states: torch.Tensor,
@@ -923,10 +914,9 @@ def sample_from_hidden(
         sample_hidden_states = \
             hidden_states[sampling_metadata.indices_do_sample]
         logits = self.compute_logits(sample_hidden_states)
-        # Greedy sampling can't be run without branching the graph on Sampler.
-        # Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
-        # NOTE do_argmax is a scalar, this is just an optimized if/else.
-        out_tokens = torch.where(sampling_metadata.do_argmax,
+        # Optimized greedy sampling branch, tracing both paths in a single pass
+        # NOTE all_greedy is a scalar, this is just an optimized if/else.
+        out_tokens = torch.where(sampling_metadata.all_greedy,
                                  torch.argmax(logits, dim=-1, keepdim=True),
                                  self.sample(logits, sampling_metadata)\
                                      .sampled_token_ids)
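
This last hunk is the heart of the single-graph trick: instead of branching in Python on all_greedy (which would trace two different XLA graphs), both the argmax path and the random-sampling path are computed and a scalar boolean tensor selects between them. A self-contained sketch of the pattern follows; torch.multinomial is a stand-in for the full vLLM Sampler, which also applies temperature and min_p first.

import torch

torch.manual_seed(0)
logits = torch.randn(4, 32)      # [padded_num_reqs, vocab_size]
all_greedy = torch.tensor(True)  # scalar tensor, not a Python bool, so it is
                                 # traced into the graph rather than branched on

greedy_tokens = torch.argmax(logits, dim=-1, keepdim=True)
# Stand-in for the Sampler's random path.
random_tokens = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

# Both operands are evaluated; `where` only selects, so the compiled graph is
# identical whether the batch is all-greedy or not.
out_tokens = torch.where(all_greedy, greedy_tokens, random_tokens)
print(out_tokens.squeeze(-1))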
