
Commit 8fdbbc8

Author: KazusatoOko (committed)

[Bug]: Batch generation from prompt_embeds fails for long prompts

Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>

1 parent: 3779eb8

File tree

1 file changed (+15, -16 lines)

vllm/worker/model_runner.py

Lines changed: 15 additions & 16 deletions
@@ -1932,24 +1932,23 @@ def execute_model(
 
         if model_input.inputs_embeds is not None:
             if self.is_driver_worker:
-                sampled = broadcast_tensor_dict(
-                    {"token_ids": output.sampled_token_ids})
-            else:
-                sampled = broadcast_tensor_dict()
-            if sampled["token_ids"] is not None:
-                sampled_token_embeds = self.model.get_input_embeddings(
-                    sampled["token_ids"].squeeze(1))
-                if self.is_driver_worker:
+                sampled_token_ids = []
+                valid_outputs = []
+                for sequence_group_output in output.outputs:
+                    if len(sequence_group_output.samples) == 0:
+                        continue
+                    assert len(sequence_group_output.samples) == 1
+                    valid_outputs.append(sequence_group_output)
+                    sampled_token_ids.append(
+                        sequence_group_output.samples[0].output_token)
+                if len(sampled_token_ids) > 0:
                     self.sampler.include_gpu_probs_tensor = \
                         orig_include_gpu_probs
-
-                    output.sampled_token_embeds = sampled_token_embeds
-
-                    for token_embed, sequence_group_output in zip(
-                            output.sampled_token_embeds, output.outputs):
-                        assert len(sequence_group_output.samples) == 1
-                        sequence_group_output.samples[
-                            0].output_embed = token_embed
+                    sampled_token_embeds = self.model.get_input_embeddings(
+                        torch.tensor(sampled_token_ids, device=self.device))
+                    for i, sequence_group_output in enumerate(valid_outputs):
+                        sequence_group_output.samples[0].output_embed = \
+                            sampled_token_embeds[i]
 
         if not self.is_driver_worker:
             return []
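
For context, here is a minimal repro-style sketch of the code path this commit fixes: batched generation from prompt embeddings where the prompts have very different lengths. It follows vLLM's documented prompt-embeds API (enable_prompt_embeds and the {"prompt_embeds": ...} prompt form); the model name, prompt texts, and sampling settings are illustrative assumptions, not taken from the commit.

# Sketch: batch generation from prompt_embeds with mixed prompt lengths.
# Assumes a vLLM build with prompt-embeds support; model and prompts are
# illustrative. Before this fix, a decoding step in which some sequence
# groups produced no sample could make the whole batch fail.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Compute prompt embeddings outside vLLM with the matching HF model.
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
embedding_layer = hf_model.get_input_embeddings()

prompts = ["Short prompt.", "A much longer prompt. " * 400]
prompt_embeds = []
for text in prompts:
    token_ids = tokenizer(text, return_tensors="pt").input_ids
    with torch.no_grad():
        prompt_embeds.append(embedding_layer(token_ids).squeeze(0))

llm = LLM(model=model_name, enable_prompt_embeds=True)
outputs = llm.generate(
    [{"prompt_embeds": e} for e in prompt_embeds],
    SamplingParams(max_tokens=32),
)
for out in outputs:
    print(out.outputs[0].text)

The patch itself handles this on the driver worker by collecting sampled token ids only from sequence groups that actually produced a sample, embedding them in a single get_input_embeddings call, and writing each embedding back to its own sequence group, so a group with empty samples (e.g. a long prompt whose prefill is still in progress) no longer breaks the batch.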
