[BugFix]: Batch generation from prompt_embeds fails for long prompts #21390
Conversation
Code Review
This pull request fixes a bug in batch generation with `prompt_embeds` by addressing a mismatch between `output.outputs` and `output.sampled_token_ids`. The suggested change avoids a larger refactoring and provides a localized solution. A high-severity comment suggests using an asynchronous utility to prevent a potential performance regression due to a synchronous host-to-device copy.
Thanks for fixing, LGTM as long as tests pass. cc @qthequartermasterman
vllm/worker/model_runner.py
Outdated
```python
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[
    0].output_embed = token_embed
torch.tensor(sampled_token_ids).to(self.device))
```
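The last line above, `torch.tensor(sampled_token_ids).to(self.device)`, is the synchronous host-to-device copy that the automated review flagged. As a rough illustration only (not the PR's code, and not necessarily the exact utility the review had in mind), the usual asynchronous alternative in plain PyTorch stages the data in pinned host memory and copies with `non_blocking=True`:

```python
import torch

def copy_token_ids_async(sampled_token_ids: list[int],
                         device: torch.device) -> torch.Tensor:
    # Build the tensor in pinned (page-locked) host memory so the copy
    # can be issued asynchronously on the current CUDA stream.
    host = torch.tensor(sampled_token_ids, dtype=torch.long, pin_memory=True)
    # non_blocking=True returns immediately; the copy overlaps with other
    # GPU work instead of stalling the CPU like a plain .to(device).
    return host.to(device, non_blocking=True)
```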
Now that token_id broadcasting is removed, does this still work for tensor parallel?
Thank you very much for your review. Your question was critical: indeed, I encountered an error when running with `tensor_parallel_size=2`. I have incorporated an additional fix. I would appreciate it if you could run the tests again at your convenience.

When I executed the above test.py script with `tensor_parallel_size=2` using `CUDA_VISIBLE_DEVICES=0,1 uv run python test.py`, I confirmed that the generation process stalled. The root cause was not the missing broadcast of `sampled["token_ids"]`, but rather a requirement I was unaware of: `self.model.get_input_embeddings` must be called on all GPUs. In the proposed fix, `sampled_token_ids` is broadcast across all ranks. This decision is based on the following rationale:

Why broadcast `sampled_token_ids` generated from `output.outputs` instead of `sampled["token_ids"]`?

Since `output.outputs` is always needed anyway, using the information obtained from it removes the need to broadcast `sampled["token_ids"]`. The original error stemmed from a mismatch between `output.outputs` and `sampled["token_ids"]`, most likely because the correspondence between the two was not maintained when they were constructed. There is therefore a positive reason to rely on only one of them, so I decided to avoid `sampled["token_ids"]` and obtain all the necessary information from `output.outputs`.

Why not just send a flag instead of broadcasting `sampled_token_ids`?

Fundamentally, only a single bit of information needs to be shared: a flag indicating whether to call `self.model.get_input_embeddings`. However, even on non-driver ranks, a dummy input is required to invoke `self.model.get_input_embeddings`, so this would introduce at least two new variables (a flag and a dummy input). I therefore considered it simpler and more practical to broadcast `sampled_token_ids` and call `self.model.get_input_embeddings` with it on all ranks.
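For concreteness, here is a minimal sketch of the control flow described above. It is an illustration, not the PR's actual diff: the `is_driver_rank` argument and the use of `torch.distributed.broadcast_object_list` are stand-ins for vLLM's own rank handling and broadcast utilities.

```python
import torch
import torch.distributed as dist

def embed_sampled_tokens(model, output, device, is_driver_rank: bool):
    # Sketch: the driver rebuilds sampled_token_ids from output.outputs and
    # broadcasts them so that every rank calls get_input_embeddings.
    payload = [None]
    if is_driver_rank:
        # Build the ids only from output.outputs, skipping empty groups, so
        # the list stays aligned with the sequences that produced tokens.
        payload[0] = [
            group.samples[0].output_token
            for group in output.outputs
            if len(group.samples) == 1
        ]
    # Share the real token ids with all ranks (vLLM uses its own helpers).
    dist.broadcast_object_list(payload, src=0)
    sampled_token_ids = payload[0]

    # VocabParallelEmbedding shards the vocabulary across ranks, so
    # get_input_embeddings must run on every rank with the same real ids.
    token_tensor = torch.tensor(sampled_token_ids, device=device)
    return model.get_input_embeddings(token_tensor)
```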
The fix looks good to me. However, for `VocabParallelEmbedding` to work, the real token ids need to be broadcast to all ranks. I don't think dummy tokens would work.
Thanks for your quick follow-up! I tested the code by forcibly using dummy tokens on non-driver workers, and indeed it failed. I had no idea that `self.model.get_input_embeddings` requires the input to be identical across all workers. I added strikethrough to the incorrect part of my previous comment. Thank you for sharing such valuable knowledge!

In any case, in my current implementation, instead of sending a flag or creating dummy tokens, I broadcast `sampled_token_ids` (constructed from `output.outputs`) as a substitute for `sampled["token_ids"]` and pass it to `self.model.get_input_embeddings` on all ranks.
Purpose
This is a fix for issue #21386: Batch generation from prompt_embeds fails for long prompts. An error occurs during batch generation from prompt embeddings when the prompts are long.

The root cause lies at line 1948 of `worker/model_runner.py`, within the `execute_model` function of the `ModelRunner` class. Specifically, `output.outputs` may contain empty elements, whereas `sampled["token_ids"]` (and consequently the `sampled_token_embeds` generated from it) only includes entries for prompts that actually produced tokens. The most fundamental solution would be to coordinate the construction of these two structures, but such a fix would require extensive code changes. As a lighter-weight workaround, this PR proposes using only `output.outputs` instead of relying on `sampled["token_ids"]`.

See issue #21386 for more details on the bug and the identification of its cause.
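As a toy illustration (plain Python, not vLLM code) of the misalignment described above: `zip` pairs the two structures positionally, so a group with no samples silently takes the embedding that belongs to a later group.

```python
# output.outputs has one entry per sequence group, some with zero samples;
# sampled_token_embeds only has entries for groups that sampled a token.
outputs = ["group0: 1 sample", "group1: 0 samples", "group2: 1 sample"]
sampled_token_embeds = ["embed for group0", "embed for group2"]

for embed, group in zip(sampled_token_embeds, outputs):
    print(group, "<-", embed)

# Prints:
#   group0: 1 sample <- embed for group0
#   group1: 0 samples <- embed for group2
# group1 (which produced no token) is paired with group2's embedding, and
# group2 is dropped entirely.
```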
Test Plan
This is a minimal example that reproduces the original bug. If this test passes, it can be considered a resolution to the issue. I also verified the fix in a more practical usage scenario on my local setup, and it appeared to work as expected.
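The original test.py is not reproduced here; the following is a hypothetical sketch of that kind of reproduction, written against vLLM's prompt-embeds input format as I understand it (the `enable_prompt_embeds` engine argument and `{"prompt_embeds": ...}` prompts). The model name, prompt lengths, and sampling settings are placeholders and may differ from the author's script.

```python
# Hypothetical reproduction sketch (not the author's test.py).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

MODEL = "meta-llama/Llama-3.2-1B-Instruct"  # placeholder model

tokenizer = AutoTokenizer.from_pretrained(MODEL)
embed_layer = AutoModelForCausalLM.from_pretrained(MODEL).get_input_embeddings()

# Batch a long prompt with a short one; long prompts are what triggered the
# mismatch between output.outputs and sampled["token_ids"].
texts = ["word " * 1500, "Hello, how are you?"]
prompts = []
for text in texts:
    ids = tokenizer(text, return_tensors="pt").input_ids
    with torch.no_grad():
        prompts.append({"prompt_embeds": embed_layer(ids).squeeze(0)})

# Depending on the vLLM version, prompt embeddings may require the V0 engine.
llm = LLM(model=MODEL, enable_prompt_embeds=True)
for out in llm.generate(prompts, SamplingParams(max_tokens=32)):
    print(out.outputs[0].text)
```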
Originally, running this code on a single GPU with `CUDA_VISIBLE_DEVICES=0 uv run python test.py` resulted in the following error:

Details of the error
Test Result
Details
Another fix
Another fix is to modify only the block starting at line 1948 of `worker/model_runner.py`, which begins with `for token_embed, sequence_group_output in zip(output.sampled_token_embeds, output.outputs):`. It skips the iteration if `len(sequence_group_output.samples) == 0`, and when `len(sequence_group_output.samples) == 1` it retrieves the embedding corresponding to `sequence_group_output.samples[0].output_token`. This is a minimal fix that bypasses the error scenario with the least amount of code change; a sketch is given below.

Finally, I would like to sincerely thank all those involved in the development of vLLM, an outstanding open-source project.
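For reference, here is a minimal sketch of what that alternative loop could look like, assuming the `ModelRunner` context described above (the `output` object, the model, and the target device); it illustrates the approach rather than reproducing the PR's diff.

```python
import torch

def attach_output_embeds(model, output, device):
    # Iterate over output.outputs directly instead of zipping it with
    # output.sampled_token_embeds, which can have fewer entries.
    for sequence_group_output in output.outputs:
        if len(sequence_group_output.samples) == 0:
            # Long prompts can leave a group with no sampled token this step.
            continue
        assert len(sequence_group_output.samples) == 1
        sample = sequence_group_output.samples[0]
        # Look up the embedding of the token this group actually sampled.
        token_id = torch.tensor([sample.output_token], device=device)
        sample.output_embed = model.get_input_embeddings(token_id).squeeze(0)
```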