Skip to content

[BugFix]: Batch generation from prompt_embeds fails for long prompts #21390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 24, 2025
36 changes: 22 additions & 14 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1932,24 +1932,32 @@ def execute_model(

if model_input.inputs_embeds is not None:
if self.is_driver_worker:
sampled = broadcast_tensor_dict(
{"token_ids": output.sampled_token_ids})
sampled_token_ids = []
valid_outputs = []
for sequence_group_output in output.outputs:
if len(sequence_group_output.samples) == 0:
continue
assert len(sequence_group_output.samples) == 1
valid_outputs.append(sequence_group_output)
sampled_token_ids.append(
sequence_group_output.samples[0].output_token)
sampled_token_ids = torch.tensor(sampled_token_ids).to(
self.device)
sampled_token_ids = broadcast_tensor_dict(
{"sampled_token_ids":
sampled_token_ids})["sampled_token_ids"]
else:
sampled = broadcast_tensor_dict()
if sampled["token_ids"] is not None:
sampled_token_embeds = self.model.get_input_embeddings(
sampled["token_ids"].squeeze(1))
sampled_token_ids = broadcast_tensor_dict(
)["sampled_token_ids"]
if len(sampled_token_ids) > 0:
sampled_token_embeds = \
self.model.get_input_embeddings(sampled_token_ids)
if self.is_driver_worker:
self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs

output.sampled_token_embeds = sampled_token_embeds

for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[
0].output_embed = token_embed
for i, sequence_group_output in enumerate(valid_outputs):
sequence_group_output.samples[0].output_embed = \
sampled_token_embeds[i]

if not self.is_driver_worker:
return []
Expand Down