[BugFix]: Batch generation from prompt_embeds fails for long prompts #21390


Merged
merged 11 commits into vllm-project:main from batch_prompt_embeds on Jul 24, 2025

Conversation

Contributor

@KazusatoOoko KazusatoOoko commented Jul 22, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results

Purpose

This is a fix for issue #21386: Batch generation from prompt_embeds fails for long prompts. When prompts are long, batch generation from prompt embeddings raises an error.

The root cause lies in line 1948 of worker/model_runner.py, inside the execute_model method of the ModelRunner class. Specifically, output.outputs may contain empty elements, whereas sampled["token_ids"] (and consequently sampled_token_embeds, which is generated from it) only contains entries for prompts that actually produced tokens.

The most fundamental solution would be to coordinate the generation of these two structures. However, such a fix would require extensive code changes. As a lighter-weight workaround, this PR proposes using only output.outputs instead of relying on sampled["token_ids"].
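As a rough sketch, the core of this approach looks like the hypothetical helper below (the helper name and the surrounding structures are assumed from the description above; this is an illustration, not the exact diff):

import torch

def attach_output_embeds(model_runner, output):
    # Derive the token ids to embed directly from output.outputs, so the
    # embeddings stay aligned with the sequence groups that actually sampled
    # a token in this step.
    sampled_token_ids = []
    for sequence_group_output in output.outputs:
        if len(sequence_group_output.samples) == 0:
            continue  # e.g. a long prompt whose prefill has not finished yet
        assert len(sequence_group_output.samples) == 1
        sampled_token_ids.append(sequence_group_output.samples[0].output_token)
    if not sampled_token_ids:
        return
    # Embed only the tokens that were actually sampled.
    embeds = model_runner.model.get_input_embeddings(
        torch.tensor(sampled_token_ids).to(model_runner.device))
    # Write each embedding back to its sequence group, in the same order.
    idx = 0
    for sequence_group_output in output.outputs:
        if len(sequence_group_output.samples) == 0:
            continue
        sequence_group_output.samples[0].output_embed = embeds[idx]
        idx += 1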

See issue #21386 for more details on the bug and how the cause was identified.

Test Plan

This is a minimal example that reproduces the original bug; if it passes, the issue can be considered resolved. I also verified the fix in a more practical usage scenario on my local setup, and it appeared to work as expected.

# test.py
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM

text = "A long time ago, in a galaxy far, far away..." # this should be >= 1000 tokens

def main():
    model_name = "Qwen/Qwen3-0.6B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
    embedding_layer = transformers_model.get_input_embeddings()
    llm = LLM(model=model_name, enable_prompt_embeds=True)
    token_ids_list = [tokenizer(text, return_tensors="pt")["input_ids"] for _ in range(8)]
    prompt_embeds_list = [embedding_layer(chat).squeeze(0) for chat in token_ids_list]

    outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])
    print("-" * 30)
    for i, o in enumerate(outputs):
        print(f"Generation {i + 1}: {o.outputs[0].text}\n")
    print("-" * 30)

if __name__ == "__main__":
    main()

Originally, running this code on a single GPU with CUDA_VISIBLE_DEVICES=0 uv run python test.py resulted in the following error:

[rank0]:   File "/home/.../vllm/vllm/vllm/worker/model_runner.py", line 1952, in execute_model
[rank0]:     assert len(sequence_group_output.samples) == 1
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError
Details of the error
INFO 07-22 14:26:50 [__init__.py:235] Automatically detected platform cuda.
INFO 07-22 14:27:12 [config.py:1593] Using max model len 40960
WARNING 07-22 14:27:13 [arg_utils.py:1696] --enable-prompt-embeds is not supported by the V1 Engine. Falling back to V0. 
WARNING 07-22 14:27:13 [arg_utils.py:1492] Chunked prefill is enabled by default for models with max_model_len > 32K. Chunked prefill might not work with some features or models. If you encounter any issues, please disable by launching with --enable-chunked-prefill=False.
INFO 07-22 14:27:14 [config.py:2414] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 07-22 14:27:14 [llm_engine.py:230] Initializing a V0 LLM engine (v0.10.0rc2.dev33+g3779eb8c8.d20250722) with config: model='Qwen/Qwen3-0.6B', speculative_config=None, tokenizer='Qwen/Qwen3-0.6B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=40960, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=Qwen/Qwen3-0.6B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=None, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":0,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":[],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":256,"local_cache_dir":null}, use_cached_outputs=False, 
INFO 07-22 14:27:16 [cuda.py:398] Using Flash Attention backend.
INFO 07-22 14:27:17 [parallel_state.py:1102] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
INFO 07-22 14:27:17 [model_runner.py:1175] Starting to load model Qwen/Qwen3-0.6B...
INFO 07-22 14:27:18 [weight_utils.py:296] Using model weights format ['*.safetensors']
INFO 07-22 14:27:18 [weight_utils.py:349] No model.safetensors.index.json found in remote.

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]

Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.04it/s]

Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.04it/s]

INFO 07-22 14:27:18 [default_loader.py:262] Loading weights took 0.28 seconds
INFO 07-22 14:27:19 [model_runner.py:1207] Model loading took 1.1201 GiB and 1.267336 seconds
INFO 07-22 14:27:20 [worker.py:296] Memory profiling takes 1.41 seconds
INFO 07-22 14:27:20 [worker.py:296] the current vLLM instance can use total_gpu_memory (79.11GiB) x gpu_memory_utilization (0.90) = 71.20GiB
INFO 07-22 14:27:20 [worker.py:296] model weights take 1.12GiB; non_torch_memory takes 0.15GiB; PyTorch activation peak memory takes 1.41GiB; the rest of the memory reserved for KV Cache is 68.51GiB.
INFO 07-22 14:27:20 [executor_base.py:115] # cuda blocks: 40090, # CPU blocks: 2340
INFO 07-22 14:27:20 [executor_base.py:120] Maximum concurrency for 40960 tokens per request: 15.66x
INFO 07-22 14:27:23 [model_runner.py:1518] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.

Capturing CUDA graph shapes: 100%|██████████| 70/70 [00:20<00:00,  3.44it/s]
INFO 07-22 14:27:43 [model_runner.py:1677] Graph capturing finished in 20 secs, took 2.31 GiB
INFO 07-22 14:27:43 [llm_engine.py:428] init engine (profile, create kv cache, warmup model) took 24.49 seconds
WARNING 07-22 14:27:43 [config.py:1517] Default sampling parameters have been overridden by the model's Hugging Face generation config recommended from the model creator. If this is not intended, please relaunch vLLM instance with `--generation-config vllm`.

Adding requests: 100%|██████████| 8/8 [00:00<00:00, 5147.17it/s]

Processed prompts:   0%|          | 0/8 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s][rank0]: Traceback (most recent call last):
[rank0]:   File "/home/.../vllm/test/test.py", line 51, in <module>
[rank0]:     main()
[rank0]:   File "/home/.../vllm/test/test.py", line 47, in main
[rank0]:     outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/.../vllm/vllm/vllm/utils/__init__.py", line 1519, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/.../vllm/vllm/vllm/entrypoints/llm.py", line 516, in generate
[rank0]:     outputs = self._run_engine(use_tqdm=use_tqdm)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/.../vllm/vllm/vllm/entrypoints/llm.py", line 1734, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/.../vllm/vllm/vllm/engine/llm_engine.py", line 1356, in step
[rank0]:     outputs = self.model_executor.execute_model(
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/.../vllm/vllm/vllm/executor/executor_base.py", line 148, in execute_model
[rank0]:     output = self.collective_rpc("execute_model",
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/.../vllm/vllm/vllm/executor/uniproc_executor.py", line 58, in collective_rpc
[rank0]:     answer = run_method(self.driver_worker, method, args, kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/.../vllm/vllm/vllm/utils/__init__.py", line 2990, in run_method
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/.../vllm/vllm/vllm/worker/worker_base.py", line 418, in execute_model
[rank0]:     output = self.model_runner.execute_model(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/.../vllm/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/.../vllm/vllm/vllm/worker/model_runner.py", line 1952, in execute_model
[rank0]:     assert len(sequence_group_output.samples) == 1
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError

Processed prompts:   0%|          | 0/8 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
[rank0]:[W722 14:27:45.514398746 ProcessGroupNCCL.cpp:1479] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

Test Result

------------------------------
Generation 1:  (omitted)

...

Generation 8:  (omitted)

------------------------------
Details
INFO 07-22 16:27:50 [__init__.py:235] Automatically detected platform cuda.
INFO 07-22 16:28:12 [config.py:1593] Using max model len 40960
WARNING 07-22 16:28:12 [arg_utils.py:1696] --enable-prompt-embeds is not supported by the V1 Engine. Falling back to V0. 
WARNING 07-22 16:28:12 [arg_utils.py:1492] Chunked prefill is enabled by default for models with max_model_len > 32K. Chunked prefill might not work with some features or models. If you encounter any issues, please disable by launching with --enable-chunked-prefill=False.
INFO 07-22 16:28:13 [config.py:2414] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 07-22 16:28:13 [llm_engine.py:230] Initializing a V0 LLM engine (v0.10.0rc2.dev33+g3779eb8c8.d20250722) with config: model='Qwen/Qwen3-0.6B', speculative_config=None, tokenizer='Qwen/Qwen3-0.6B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=40960, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=Qwen/Qwen3-0.6B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=None, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":0,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":[],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":256,"local_cache_dir":null}, use_cached_outputs=False, 
INFO 07-22 16:28:15 [cuda.py:398] Using Flash Attention backend.
INFO 07-22 16:28:16 [parallel_state.py:1102] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
INFO 07-22 16:28:16 [model_runner.py:1175] Starting to load model Qwen/Qwen3-0.6B...
INFO 07-22 16:28:16 [weight_utils.py:296] Using model weights format ['*.safetensors']
INFO 07-22 16:28:16 [weight_utils.py:349] No model.safetensors.index.json found in remote.

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]

Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.44it/s]

Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.44it/s]

INFO 07-22 16:28:17 [default_loader.py:262] Loading weights took 0.26 seconds
INFO 07-22 16:28:17 [model_runner.py:1207] Model loading took 1.1201 GiB and 0.644639 seconds
INFO 07-22 16:28:18 [worker.py:296] Memory profiling takes 0.50 seconds
INFO 07-22 16:28:18 [worker.py:296] the current vLLM instance can use total_gpu_memory (79.11GiB) x gpu_memory_utilization (0.90) = 71.20GiB
INFO 07-22 16:28:18 [worker.py:296] model weights take 1.12GiB; non_torch_memory takes 0.15GiB; PyTorch activation peak memory takes 1.41GiB; the rest of the memory reserved for KV Cache is 68.51GiB.
INFO 07-22 16:28:18 [executor_base.py:115] # cuda blocks: 40090, # CPU blocks: 2340
INFO 07-22 16:28:18 [executor_base.py:120] Maximum concurrency for 40960 tokens per request: 15.66x
INFO 07-22 16:28:20 [model_runner.py:1518] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.

Capturing CUDA graph shapes: 100%|██████████| 70/70 [00:19<00:00,  3.58it/s]
INFO 07-22 16:28:40 [model_runner.py:1677] Graph capturing finished in 20 secs, took 2.31 GiB
INFO 07-22 16:28:40 [llm_engine.py:428] init engine (profile, create kv cache, warmup model) took 22.76 seconds
WARNING 07-22 16:28:40 [config.py:1517] Default sampling parameters have been overridden by the model's Hugging Face generation config recommended from the model creator. If this is not intended, please relaunch vLLM instance with `--generation-config vllm`.

Adding requests: 100%|██████████| 8/8 [00:00<00:00, 4375.33it/s]

Processed prompts: 100%|██████████| 8/8 [00:00<00:00,  4.42it/s, est. speed input: 30083.54 toks/s, output: 531.26 toks/s]
Processed prompts: 100%|██████████| 8/8 [00:00<00:00, 33.18it/s, est. speed input: 30083.54 toks/s, output: 531.26 toks/s]
------------------------------
Generation 1:  (omitted)

...

Generation 8:  (omitted)

------------------------------
[rank0]:[W722 16:28:41.076200212 ProcessGroupNCCL.cpp:1479] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

Another fix

An alternative fix modifies only the block starting at line 1948 of worker/model_runner.py, which begins with for token_embed, sequence_group_output in zip(output.sampled_token_embeds, output.outputs):. The modified block skips the iteration when len(sequence_group_output.samples) == 0 and, when len(sequence_group_output.samples) == 1, looks up the embedding corresponding to sequence_group_output.samples[0].output_token. This is a minimal change that bypasses the error scenario with the least amount of code churn.

                    output.sampled_token_embeds = sampled_token_embeds

                    #####[Begin editing]#####
                    for sequence_group_output in output.outputs:
                        # Sequence groups that produced no token this step have
                        # empty samples; skip them instead of asserting.
                        if len(sequence_group_output.samples) == 0:
                            continue
                        assert len(sequence_group_output.samples) == 1
                        # Find the row of sampled_token_embeds whose token id
                        # matches this sequence group's sampled token.
                        matches = (sampled["token_ids"].squeeze(1) ==
                                   sequence_group_output.samples[0].output_token)
                        assert matches.any()
                        embed_idx = matches.nonzero(as_tuple=False)[0].item()
                        sequence_group_output.samples[0].output_embed = \
                            output.sampled_token_embeds[embed_idx, :]
                    #####[End editing]#####

        if not self.is_driver_worker:
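
Compared with the fix adopted in this PR, this variant keeps sampled["token_ids"] as the source of the sampled embeddings and merely tolerates sequence groups with empty samples, whereas the adopted fix derives the token ids to embed from output.outputs alone.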

Finally, I would like to sincerely thank all those involved in the development of vLLM, an outstanding open source project.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request fixes a bug in batch generation with prompt_embeds by addressing a mismatch between output.outputs and output.sampled_token_ids. The suggested change avoids a larger refactoring and provides a localized solution. A high-severity comment suggests using an asynchronous utility to prevent a potential performance regression due to a synchronous host-to-device copy.

Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
@KazusatoOoko KazusatoOoko force-pushed the batch_prompt_embeds branch from f3ec4c0 to 8fdbbc8 Compare July 22, 2025 16:38
KazusatoOko added 3 commits July 22, 2025 16:52
Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
@KazusatoOoko KazusatoOoko changed the title [Bug]: Batch generation from prompt_embeds fails for long prompts [BugFix]: Batch generation from prompt_embeds fails for long prompts Jul 22, 2025
Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
Member

@DarkLight1337 DarkLight1337 left a comment

Thanks for fixing, LGTM as long as tests pass. cc @qthequartermasterman

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) July 23, 2025 07:24
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 23, 2025
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[0].output_embed = token_embed
torch.tensor(sampled_token_ids).to(self.device))
Contributor

Now that token_id broadcasting is removed, does this still work for tensor parallel?

Contributor Author

@KazusatoOoko KazusatoOoko Jul 23, 2025

Thank you very much for your review. Your question turned out to be critical: indeed, I encountered an error when running with tensor_parallel_size=2. I have incorporated an additional fix. I would appreciate it if you could run the tests again at your convenience.

When I executed the above test.py script with tensor_parallel_size=2 using CUDA_VISIBLE_DEVICES=0,1 uv run python test.py, the generation process stalled. The root cause was not the missing broadcast of sampled["token_ids"], but a requirement I was unaware of: self.model.get_input_embeddings must be called on all GPUs. In the proposed fix, sampled_token_ids is therefore broadcast across all ranks. This decision is based on the following rationale:

Why broadcast sampled_token_ids generated from output.outputs instead of sampled["token_ids"]?
output.outputs is always needed anyway, and by using the information it already contains there is no need to broadcast sampled["token_ids"] at all. The original error stemmed from a mismatch between output.outputs and sampled["token_ids"], likely because their construction does not keep them in correspondence, so there is a positive reason to rely on only one of the two here. As a result, I decided to avoid using sampled["token_ids"] and to obtain all the necessary information from output.outputs.

Why not just send a flag instead of broadcasting sampled_token_ids?
Fundamentally, only a 1-bit piece of information—a flag indicating whether to call self.model.get_input_embeddings—needs to be shared. However, even in non-driver ranks, a dummy variable is required to invoke self.model.get_input_embeddings. This would require at least two new variables (a flag and a dummy input). Therefore, I considered it simpler and more practical to broadcast sampled_token_ids and call self.model.get_input_embeddings with sampled_token_ids in all ranks.
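For illustration only, here is a rough sketch of the broadcasting idea written with plain torch.distributed (vLLM has its own broadcast utilities, so the function and variable names below are hypothetical, and src=0 assumes the driver is rank 0):

import torch
import torch.distributed as dist

def embed_sampled_tokens(model, sampled_token_ids, is_driver_worker, device):
    # The driver builds sampled_token_ids from output.outputs; the other ranks
    # first learn how many ids to expect, then receive the ids themselves, so
    # that every rank calls get_input_embeddings with exactly the same input.
    n = torch.tensor([len(sampled_token_ids) if is_driver_worker else 0],
                     dtype=torch.long, device=device)
    dist.broadcast(n, src=0)
    if is_driver_worker:
        token_ids = torch.tensor(sampled_token_ids, dtype=torch.long, device=device)
    else:
        token_ids = torch.empty(int(n.item()), dtype=torch.long, device=device)
    dist.broadcast(token_ids, src=0)
    # Tensor-parallel embedding layers require identical ids on all ranks.
    return model.get_input_embeddings(token_ids)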

Contributor

The fix looks good to me. However, for VocabParallelEmbedding to work, the real token ids need to be broadcast to all ranks. I don't think dummy tokens would work.
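For intuition only, a simplified sketch of why a vocab-parallel embedding needs the same real token ids on every rank (this is not vLLM's actual VocabParallelEmbedding code; names and shapes are illustrative):

import torch
import torch.distributed as dist

def vocab_parallel_embed(token_ids, weight_shard, vocab_start, vocab_end):
    # Each tensor-parallel rank owns only rows [vocab_start, vocab_end) of the
    # embedding table; weight_shard has shape [vocab_end - vocab_start, hidden].
    in_shard = (token_ids >= vocab_start) & (token_ids < vocab_end)
    local_ids = torch.where(in_shard, token_ids - vocab_start,
                            torch.zeros_like(token_ids))
    out = weight_shard[local_ids]
    # Rows owned by other ranks contribute nothing from this rank.
    out = out * in_shard.unsqueeze(-1).to(out.dtype)
    # The full embedding is the sum over all ranks, so every rank must receive
    # the same token_ids; dummy ids on some ranks would either corrupt the sum
    # or miss the rank that actually owns the real token's row.
    dist.all_reduce(out, op=dist.ReduceOp.SUM)
    return out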

Contributor Author

@KazusatoOoko KazusatoOoko Jul 23, 2025

Thanks for your quick follow-up! I tested the code by forcibly using dummy tokens on non-driver workers, and indeed, it failed. I had no idea that self.model.get_input_embeddings requires the input to be identical across all workers. I added strikethrough to the incorrect part of my previous comment. Thank you for sharing such valuable knowledge!

Anyway, in my current implementation, instead of sending a flag or creating dummy tokens, I broadcast sampled_token_ids (which is constructed from output.outputs) as a substitute for sampled["token_ids"], and pass it to self.model.get_input_embeddings on all ranks.

Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
auto-merge was automatically disabled July 23, 2025 15:27

Head branch was pushed to by a user without write access

KazusatoOko added 5 commits July 23, 2025 15:51
Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
@vllm-bot vllm-bot merged commit fd48d99 into vllm-project:main Jul 24, 2025
63 of 65 checks passed
DW934 pushed a commit to DW934/vllm that referenced this pull request Jul 24, 2025
…llm-project#21390)

Signed-off-by: KazusatoOko <kazusto.oko@sakana.ai>
Co-authored-by: KazusatoOko <kazusto.oko@sakana.ai>
Signed-off-by: 董巍 <dongwei@U-5XFVDYCF-2058.local>