
[Frontend] Add chunked processing to handle long inputs in embedding models #20837


Open

x22x22 wants to merge 14 commits into main from feat/support-long-text-embedding

Conversation


@x22x22 x22x22 commented Jul 11, 2025

…g, and update relevant documentation and examples. New example scripts and service startup scripts are added to demonstrate how to configure and use chunked processing. Update the model configuration to support long-text processing and implement the chunked processing logic in the code.

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Add chunked processing support for long text embeddings to resolve CUDA crashes when input text exceeds the model's maximum context length.

Problem Solved

  • CUDA crashes: vLLM embedding service crashes when processing text longer than max_model_len
  • Limited input length: No native support for handling arbitrarily long text in embedding models
  • Memory constraints: Large inputs cause out-of-memory errors during embedding generation

Solution

This PR implements automatic chunked processing at the serving layer that:

  • ✅ Automatically detects when input exceeds model limits
  • ✅ Splits long text into manageable chunks at token boundaries
  • ✅ Processes each chunk independently to avoid memory issues
  • ✅ Aggregates results using FastChat-style weighted averaging (see the sketch after this list)
  • ✅ Maintains backward compatibility for short text inputs
  • ✅ Requires zero changes to existing model implementations
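
As a rough illustration of the flow above, the sketch below splits an over-long token sequence at token boundaries, embeds each chunk independently, and combines the results with token-count weighted averaging. The names (`chunk_token_ids`, `embed_long_input`, `embed_chunk`) are illustrative placeholders, not the helpers actually used in this PR:

import numpy as np

def chunk_token_ids(token_ids: list[int], max_len: int) -> list[list[int]]:
    """Split a token sequence into consecutive chunks of at most max_len tokens."""
    return [token_ids[i:i + max_len] for i in range(0, len(token_ids), max_len)]

def embed_long_input(token_ids: list[int], max_len: int, embed_chunk) -> np.ndarray:
    """Embed each chunk independently, then combine them with token-count weighted averaging."""
    chunks = chunk_token_ids(token_ids, max_len)
    embeddings = np.stack([np.asarray(embed_chunk(c), dtype=np.float64) for c in chunks])
    weights = np.array([len(c) for c in chunks], dtype=np.float64)   # FastChat-style: weight by chunk length
    aggregated = (embeddings * weights[:, None]).sum(axis=0) / weights.sum()
    return aggregated / np.linalg.norm(aggregated)                   # re-normalize the combined embedding

A short input produces exactly one chunk, so the aggregation reduces to the ordinary embedding path, which is how backward compatibility is preserved.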

Key Features

  • Zero model code modification: All logic implemented in serving layer
  • Configurable: Enabled via enable_chunked_processing: true in pooler config
  • Smart aggregation: Token count-based weighted averaging preserves semantic quality
  • Production ready: Comprehensive error handling and logging

Supported Models

  • intfloat/multilingual-e5-large (initially)
  • Extensible architecture for other embedding models

This enables vLLM to handle embedding requests of any length without crashes, significantly expanding its utility for RAG applications and long document processing.

Test Plan

Long Text Embedding with Chunked Processing
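
For reference, a minimal client call of the kind exercised here might look like the sketch below. The endpoint, port, and model name are taken from the client output further down; the actual test script is examples/online_serving/openai_embedding_long_text_client.py, so treat this only as a sketch of the request shape:

from openai import OpenAI

# Connect to the OpenAI-compatible vLLM server started by the example service script.
client = OpenAI(base_url="http://localhost:31090/v1", api_key="your-api-key")

# A document far longer than the model's native 512-token window.
long_text = "Artificial intelligence is transforming technology. " * 2000

resp = client.embeddings.create(model="multilingual-e5-large", input=long_text)
embedding = resp.data[0].embedding
print(len(embedding))  # 1024 for multilingual-e5-large; chunking happens server-side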

Test Result

Before modification

  • serve
ERROR 07-12 02:52:36 [engine.py:165] RuntimeError('CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmEx( handle, opa, opb, m, n, k, alpha_ptr, a, CUDA_R_16F, lda, b, CUDA_R_16F, ldb, beta_ptr, c, CUDA_R_16F, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`')
ERROR 07-12 02:52:36 [engine.py:165] Traceback (most recent call last):
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/engine/multiprocessing/engine.py", line 163, in start 
ERROR 07-12 02:52:36 [engine.py:165]     self.run_engine_loop()
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/engine/multiprocessing/engine.py", line 226, in run_engine_loop
ERROR 07-12 02:52:36 [engine.py:165]     request_outputs = self.engine_step()
ERROR 07-12 02:52:36 [engine.py:165]                       ^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/engine/multiprocessing/engine.py", line 252, in engine_step
ERROR 07-12 02:52:36 [engine.py:165]     raise e
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/engine/multiprocessing/engine.py", line 235, in engine_step
ERROR 07-12 02:52:36 [engine.py:165]     return self.engine.step()
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/engine/llm_engine.py", line 1356, in step
ERROR 07-12 02:52:36 [engine.py:165]     outputs = self.model_executor.execute_model(
ERROR 07-12 02:52:36 [engine.py:165]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/executor/executor_base.py", line 141, in execute_model
ERROR 07-12 02:52:36 [engine.py:165]     output = self.collective_rpc("execute_model",
ERROR 07-12 02:52:36 [engine.py:165]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
ERROR 07-12 02:52:36 [engine.py:165]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 07-12 02:52:36 [engine.py:165]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/utils/__init__.py", line 2943, in run_method
ERROR 07-12 02:52:36 [engine.py:165]     return func(*args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/worker/worker_base.py", line 420, in execute_model
ERROR 07-12 02:52:36 [engine.py:165]     output = self.model_runner.execute_model(
ERROR 07-12 02:52:36 [engine.py:165]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 07-12 02:52:36 [engine.py:165]     return func(*args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/worker/pooling_model_runner.py", line 119, in execute_model
ERROR 07-12 02:52:36 [engine.py:165]     hidden_or_intermediate_states = model_executable(
ERROR 07-12 02:52:36 [engine.py:165]                                     ^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 07-12 02:52:36 [engine.py:165]     return self._call_impl(*args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 07-12 02:52:36 [engine.py:165]     return forward_call(*args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/model_executor/models/bert.py", line 415, in forward
ERROR 07-12 02:52:36 [engine.py:165]     return self.model(input_ids=input_ids,
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 07-12 02:52:36 [engine.py:165]     return self._call_impl(*args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 07-12 02:52:36 [engine.py:165]     return forward_call(*args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/model_executor/models/bert.py", line 350, in forward
ERROR 07-12 02:52:36 [engine.py:165]     return self.encoder(hidden_states)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/compilation/decorators.py", line 246, in __call__
ERROR 07-12 02:52:36 [engine.py:165]     model_output = self.forward(*args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/model_executor/models/bert.py", line 114, in forward
ERROR 07-12 02:52:36 [engine.py:165]     def forward(
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 07-12 02:52:36 [engine.py:165]     return self._call_impl(*args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 07-12 02:52:36 [engine.py:165]     return forward_call(*args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
ERROR 07-12 02:52:36 [engine.py:165]     return fn(*args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/fx/graph_module.py", line 830, in call_wrapped
ERROR 07-12 02:52:36 [engine.py:165]     return self._wrapped_call(self, *args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/fx/graph_module.py", line 406, in __call__
ERROR 07-12 02:52:36 [engine.py:165]     raise e
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/fx/graph_module.py", line 393, in __call__
ERROR 07-12 02:52:36 [engine.py:165]     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 07-12 02:52:36 [engine.py:165]     return self._call_impl(*args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 07-12 02:52:36 [engine.py:165]     return forward_call(*args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "<eval_with_key>.2", line 294, in forward
ERROR 07-12 02:52:36 [engine.py:165]     submod_0 = self.submod_0(l_hidden_states_,...l_self_modules_layer_modules_23_modules_output_modules_layer_norm_parameters_bias_ = None
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/llm/vllm-250711/vllm/compilation/cuda_piecewise_backend.py", line 117, in __call__
ERROR 07-12 02:52:36 [engine.py:165]     return self.compiled_graph_for_general_shape(*args)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 2143, in wrapper
ERROR 07-12 02:52:36 [engine.py:165]     return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec)
ERROR 07-12 02:52:36 [engine.py:165]                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
ERROR 07-12 02:52:36 [engine.py:165]     return fn(*args, **kwargs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1201, in forward
ERROR 07-12 02:52:36 [engine.py:165]     return compiled_fn(full_args)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 328, in runtime_wrapper
ERROR 07-12 02:52:36 [engine.py:165]     all_outs = call_func_at_runtime_with_args(
ERROR 07-12 02:52:36 [engine.py:165]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
ERROR 07-12 02:52:36 [engine.py:165]     out = normalize_as_list(f(args))
ERROR 07-12 02:52:36 [engine.py:165]                             ^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 689, in inner_fn
ERROR 07-12 02:52:36 [engine.py:165]     outs = compiled_fn(args)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 495, in wrapper
ERROR 07-12 02:52:36 [engine.py:165]     return compiled_fn(runtime_args)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/_inductor/output_code.py", line 460, in __call__
ERROR 07-12 02:52:36 [engine.py:165]     return self.current_callable(inputs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/data/conda/envs/vllm-250411/lib/python3.11/site-packages/torch/_inductor/utils.py", line 2404, in run
ERROR 07-12 02:52:36 [engine.py:165]     return model(new_inputs)
ERROR 07-12 02:52:36 [engine.py:165]            ^^^^^^^^^^^^^^^^^
ERROR 07-12 02:52:36 [engine.py:165]   File "/hs_data/.cache/vllm/torch_compile_cache/12188d34d2/rank_0_0/inductor_cache/xq/cxqsnh7zlyb6wqrdkusizoacfp34wawoczfn2qrddhljgmde7x2e.py", line 520, in call
ERROR 07-12 02:52:36 [engine.py:165]     extern_kernels.mm(reinterpret_tensor(buf1, (s0, 1024), (1024, 1), 0), reinterpret_tensor(arg4_1, (1024, 1024), (1, 1024), 0), out=buf4)
ERROR 07-12 02:52:36 [engine.py:165] RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmEx( handle, opa, opb, m, n, k, alpha_ptr, a, CUDA_R_16F, lda, b, CUDA_R_16F, ldb, beta_ptr, c, CUDA_R_16F, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
[rank0]:[W712 02:52:37.923419125 ProcessGroupNCCL.cpp:1476] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [2509407]

After modification

  • serve
INFO 07-12 02:20:40 [logger.py:43] Received request embd-b02f362e260a4e218c570cc6ab1fb346-chunk-0: prompt: '', params: PoolingParams(dimensions=None, use_cross_encoder=False, additional_metadata=None), prompt_token_ids: [0, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034], prompt_embeds shape: None, lora_request: None, prompt_adapter_request: None.
INFO 07-12 02:20:40 [logger.py:43] Received request embd-b02f362e260a4e218c570cc6ab1fb346-chunk-1: prompt: '', params: PoolingParams(dimensions=None, use_cross_encoder=False, additional_metadata=None), prompt_token_ids: [7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433], prompt_embeds shape: None, lora_request: None, prompt_adapter_request: None.
INFO 07-12 02:20:40 [logger.py:43] Received request embd-b02f362e260a4e218c570cc6ab1fb346-chunk-2: prompt: '', params: PoolingParams(dimensions=None, use_cross_encoder=False, additional_metadata=None), prompt_token_ids: [214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215], prompt_embeds shape: None, lora_request: None, prompt_adapter_request: None.
INFO 07-12 02:20:40 [logger.py:43] Received request embd-b02f362e260a4e218c570cc6ab1fb346-chunk-3: prompt: '', params: PoolingParams(dimensions=None, use_cross_encoder=False, additional_metadata=None), prompt_token_ids: [6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 39215, 6892, 2408, 3034, 7986, 100, 7839, 12225, 9433, 214, 44622, 1363, 5, 2], prompt_embeds shape: None, lora_request: None, prompt_adapter_request: None.
INFO 07-12 02:20:40 [engine.py:317] Added request embd-b02f362e260a4e218c570cc6ab1fb346-chunk-0.
INFO 07-12 02:20:40 [engine.py:317] Added request embd-b02f362e260a4e218c570cc6ab1fb346-chunk-1.
INFO 07-12 02:20:40 [engine.py:317] Added request embd-b02f362e260a4e218c570cc6ab1fb346-chunk-2.
  • client
# python ./examples/online_serving/openai_embedding_long_text_client.py
🚀 vLLM Long Text Embedding Client
📡 Connecting to: http://localhost:31090/v1
🤖 Model: multilingual-e5-large
🔑 API Key: ********-key
🧪 Testing vLLM Long Text Embedding with Chunked Processing
======================================================================

📝 Test 1: Short Text
Text length: 42 characters
✅ Success!
   - Embedding dimension: 1024
   - Processing time: 0.54s
   - Expected chunks: ~1
   - First 5 values: [0.01232257578521967, 0.009728744626045227, -0.014059314504265785, -0.03867439180612564, 0.037110574543476105]

📝 Test 2: Medium Text
Text length: 3200 characters
✅ Success!
   - Embedding dimension: 1024
   - Processing time: 0.04s
   - Expected chunks: ~1
   - First 5 values: [0.04108031839132309, -0.009568133391439915, -0.028527623042464256, -0.04032902047038078, 0.020682798698544502]

📝 Test 3: Long Text (2 chunks)
Text length: 27250 characters
✅ Success!
   - Embedding dimension: 1024
   - Processing time: 0.07s
   - Expected chunks: ~2
   - First 5 values: [0.04508449137210846, -0.017967931926250458, -0.014230169355869293, -0.03835897892713547, 0.003280746517702937]

📝 Test 4: Very Long Text (3+ chunks)
Text length: 88000 characters
✅ Success!
   - Embedding dimension: 1024
   - Processing time: 0.16s
   - Expected chunks: ~3
   - First 5 values: [0.03270554542541504, 0.0007968051359057426, -0.016265524551272392, -0.03590775281190872, -0.009043066762387753]

🔄 Testing Batch Embedding with Mixed Lengths
==================================================
✅ Batch processing successful!
   - Number of inputs: 4
   - Number of embeddings: 4
   - Total processing time: 0.08s
   - Average time per input: 0.02s
   - Input 1: 12 chars → 1024D embedding
   - Input 2: 860 chars → 1024D embedding
   - Input 3: 18 chars → 1024D embedding
   - Input 4: 20000 chars → 1024D embedding

🔍 Testing Embedding Consistency
========================================
   - Generated embedding 1
   - Generated embedding 2
   - Generated embedding 3
✅ Consistency test completed!
   - Cosine similarity between runs: 1.000000
   - Expected: ~1.0 (identical embeddings)
   - ✅ High consistency achieved!

======================================================================
🎉 All tests completed!

💡 Key Features Demonstrated:
   - ✅ Automatic chunked processing for long text
   - ✅ Seamless handling of mixed-length batches
   - ✅ Consistent embedding generation
   - ✅ Backward compatibility with short text

📚 For more information, see:
   - Documentation: https://docs.vllm.ai/en/latest/models/pooling_models.html
   - Chunked Processing Guide: openai_embedding_long_text.md

(Optional) Documentation Update


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @x22x22, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant enhancement to vLLM's embedding capabilities by adding support for chunked processing of long text inputs. This feature directly addresses issues like CUDA crashes and memory constraints when processing text longer than a model's max_model_len, allowing vLLM to generate embeddings for arbitrarily long documents without requiring changes to the underlying models. The solution is implemented at the serving layer, providing a configurable and robust mechanism for handling large inputs while maintaining embedding quality through smart aggregation.

Highlights

  • New Feature: Chunked Processing for Embedding Models: Introduced automatic chunked processing at the serving layer for embedding models. This enables vLLM to handle text inputs that exceed the model's maximum context length by splitting them into manageable chunks, processing each independently, and aggregating the results.
  • Configuration and Activation: The chunked processing feature is configurable via the PoolerConfig by setting enable_chunked_processing: true. It automatically detects when input exceeds max_model_len and triggers the chunking logic.
  • Intelligent Aggregation: Implemented a FastChat-style weighted averaging algorithm to combine embeddings from multiple chunks. This method uses token counts as weights, ensuring that longer chunks contribute proportionally more to the final aggregated embedding, preserving semantic quality.
  • Backward Compatibility and Model Support: The implementation maintains backward compatibility for short text inputs and requires zero modifications to existing model code. Initially, intfloat/multilingual-e5-large is explicitly supported, with an extensible architecture for other embedding models.
  • Documentation and Examples: Added comprehensive documentation detailing the feature, its configuration, how it works, performance characteristics, and limitations. New example scripts (server and client) are provided to demonstrate how to configure and utilize chunking processing for long text embeddings.

@mergify mergify bot added the documentation and frontend labels on Jul 11, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a valuable feature for handling long text embeddings by implementing chunked processing. The code is generally well-written, and the documentation and examples are thorough. I've identified a critical bug in the batch processing logic that could lead to incorrect results when multiple long prompts are sent in a single request. I've also provided several suggestions to improve code quality, maintainability, and performance. Once the critical issue is addressed, this will be a great addition to the project.

…g, and update relevant documentation and examples. New example scripts and service startup scripts are added to demonstrate how to configure and use chunked processing. Update the model configuration to support long-text processing and implement the chunked processing logic in the code.

Signed-off-by: x22x22 <wadeking@qq.com>
@x22x22 x22x22 force-pushed the feat/support-long-text-embedding branch from b5f245d to 5398bbd on July 11, 2025 19:01
x22x22 added 2 commits July 12, 2025 03:21
… with isort, and ensure the accuracy of docstrings.

Signed-off-by: x22x22 <wadeking@qq.com>
…ompts, and improve the implementation of chunk processing to ensure accuracy and efficiency when handling long texts. Meanwhile, relevant type annotations have been updated to enhance code readability and type safety.

Signed-off-by: x22x22 <wadeking@qq.com>
@x22x22 x22x22 changed the title [Core] Add chunked processing to handle long inputs in embedding models [Frontend] Add chunked processing to handle long inputs in embedding models Jul 11, 2025
x22x22 added 5 commits July 12, 2025 04:06
…ess of block IDs and fix the block ID conflicts in batch processing. Updated relevant examples to demonstrate the new features.

Signed-off-by: x22x22 <wadeking@qq.com>
…ess of block IDs and fix the block ID conflicts in batch processing. Updated relevant examples to demonstrate the new features.

Signed-off-by: x22x22 <wadeking@qq.com>
…f the "Slow Processing" section from 1 to 3 to ensure the accuracy and consistency of the list.

Signed-off-by: x22x22 <wadeking@qq.com>
…_CODE to enhance the flexibility of the model name, and use this variable to replace the hard-coded model name in the output information. Ensure that the configuration during service startup is more consistent and maintainable.

Signed-off-by: x22x22 <wadeking@qq.com>
…verify the uniqueness of block IDs and resolve the block ID conflict issues in batch processing. Meanwhile, relevant documents and examples have been updated to ensure the accuracy and consistency of long-text processing.

Signed-off-by: x22x22 <wadeking@qq.com>
@DarkLight1337
Member

cc @maxdebayser @22quinn @noooop

@noooop
Contributor

noooop commented Jul 12, 2025

In fact, embedding models are not very suitable for handling extremely long inputs, as too much content can lead to embeddings that are not able to effectively distinguish between similar content.

Here's a simple way to confirm that automatic chunked processing is working effectively:

Reference mteb_test_embed_models in vllm/tests/models/language/pooling/mteb_utils.py and https://github.com/noooop/snippet/blob/main/benchmarks/test_mteb/test_speed.py

Keeping only the very front part of long context, such as 2048 or even 512, is an extremely high baseline.
Refer to LongEmbed: Extending Embedding Models for Long Context Retrieval

However, it still suffers from biased distribution of key information, as demonstrated in Figure 2. With only 512 context length, E5Base achieves >85% nDCG scores on 3 out of 8 publicly available LoCo tasks.

Do the following three comparative experiments

  • max_model_len = 2048
  • max_model_len =8102
  • max_model_len = 2048 + automatic chunked processing

If automatic chunked processing using multilingual-e5-large on the mteb/T2Reranking dataset (or any test with a context exceeding 8K) can achieve comparable results, it indicates that automatic chunked processing is effective.

@x22x22
Author

x22x22 commented Jul 12, 2025

In fact, embedding models are not very suitable for handling extremely long inputs, as too much content can lead to embeddings that are not able to effectively distinguish between similar content.

Here's a simple way to confirm that automatic chunked processing is working effectively:

Reference mteb_test_embed_models in vllm/tests/models/language/pooling/mteb_utils.py and https://github.com/noooop/snippet/blob/main/benchmarks/test_mteb/test_speed.py

Keeping only the very front part of long context, such as 2048 or even 512, is an extremely high baseline. Refer to LongEmbed: Extending Embedding Models for Long Context Retrieval

However, it still suffers from biased distribution of key information, as demonstrated in Figure 2. With only 512 context length, E5Base achieves >85% nDCG scores on 3 out of 8 publicly available LoCo tasks.

Do the following three comparative experiments

  • max_model_len = 2048
  • max_model_len =8102
  • max_model_len = 2048 + automatic chunked processing

If automatic chunked processing using multilingual-e5-large on mteb/T2Reranking dataset(or any test with a context exceeding 8K), can achieve comparable results indicates that automatic chunked processing is effective

@noooop I've manually tested using text chunks exceeding 1,000 tokens in vector databases, and confirmed that short user queries or task descriptions (~100 tokens) can successfully retrieve relevant text fragments.

While this verification isn't scientifically rigorous, it demonstrates a viable practical solution. I'll allocate time later to run the benchmark tests you recommended - appreciate the suggestion.

@noooop
Contributor

noooop commented Jul 13, 2025

@x22x22

After some investigation, intfloat/multilingual-e5-large uses the classic BERT architecture with a context length of 512, which appears very weak in 2025. Please perform a comparative test using jina-embeddings-v3, which has a maximum context length of 8192 and uses mean pooling.

Unless you use VLLM_ALLOW_LONG_MAX_MODEL_LEN or similar, you should not set the context of intfloat/multilingual-e5-large beyond 512, as it will exceed position_embeddings and cause an out-of-bounds error. It is not a bug. Please weaken or remove the content related to CUDA crashes.

@x22x22
Author

x22x22 commented Jul 13, 2025

@x22x22

After some investigation, intfloat/multilingual-e5-large uses the classic BERT architecture with a context length of 512, which appears very weak in 2025. Please perform a comparative test using jina-embeddings-v3, which has a maximum context length of 8192 and uses mean pooling.

Unless you use VLLM_ALLOW_LONG_MAX_MODEL_LEN or similar, you Should Not Allow set the context of intfloat/multilingual-e5-large beyond 512, as it will exceed position_embeddings and cause an out-of-bounds error. It is not a bug. Please weaken or remove the content related to CUDA crashes.

@noooop
This enhancement specifically leverages VLLM_ALLOW_LONG_MAX_MODEL_LEN, and you can see the corresponding launch code in my test script here:
https://github.com/vllm-project/vllm/blob/da812672715ac5bb09a4e5e4acb1d6d2d59feca7/examples/online_serving/openai_embedding_long_text_service.sh

The purpose is to enable models like multilingual-e5-large to support longer contexts through sharding without modifying the model's original code. The same principle applies to other embedding models - for example, if you want jina-embeddings-v3 to support beyond its native 8192 context length, simply adjusting the MAX_MODEL_LEN parameter would achieve this.

While this approach may not deliver optimal embedding performance, it provides a practical low-cost solution for RAG scenarios requiring simultaneous processing of both short and long texts. Crucially, no performance penalty occurs when input stays within a model's native context limit (e.g. ≤512 for E5, ≤8192 for Jina), as no special chunking gets triggered.


Would you be open to continuing this discussion more efficiently via https://slack.vllm.ai? I've requested access to the Slack workspace but haven't received approval yet - perhaps we could connect there once I'm onboarded.

@noooop
Contributor

noooop commented Jul 13, 2025

I looked through the code carefully.

You can add a new parameter such as max_embed_len, but do not modify any code related to max_model_len; that will cause a huge number of bugs.

And do not use VLLM_ALLOW_LONG_MAX_MODEL_LEN.

I think we should remove VLLM_ALLOW_LONG_MAX_MODEL_LEN. I can’t think of any use case that would require this flag.

@x22x22
Author

x22x22 commented Jul 13, 2025

I looked through the code carefully. You can add a new parameter such as max_embed_len, but do not modify any code related to max_model_len, and do not use VLLM_ALLOW_LONG_MAX_MODEL_LEN. That will cause a huge number of bugs

@noooop
Understood - I'll modify the code tomorrow following your guidance. Instead of the VLLM_ALLOW_LONG_MAX_MODEL_LEN approach, I'll implement a dedicated max_embed_len parameter to handle extended context lengths for embedding models. This will avoid any interference with the core max_model_len logic and prevent potential side effects.

Regarding communication, would you be open to continuing this discussion through a more efficient channel? I'd appreciate if we could connect either via:

  1. https://slack.vllm.ai (I'm still awaiting access approval), or
  2. WeChat if that's more convenient

Would either of these options work better for real-time collaboration? Thank you for your guidance on this implementation!

@noooop
Contributor

noooop commented Jul 13, 2025

Regarding communication, would you be open to continuing this discussion through a more efficient channel? I'd appreciate if we could connect either via:

  1. https://slack.vllm.ai (I'm still awaiting access approval), or
  2. WeChat if that's more convenient

I’m extremely socially anxious.

@x22x22
Author

x22x22 commented Jul 13, 2025

Regarding communication, would you be open to continuing this discussion through a more efficient channel? I'd appreciate if we could connect either via:

  1. https://slack.vllm.ai (I'm still awaiting access approval), or
  2. WeChat if that's more convenient

I’m extremely socially anxious.

@noooop

I completely understand, I also have social anxiety. This way of communicating is pretty good too 😄

I'll modify the code according to your suggestions, expecting to have it done by tomorrow~ If there's anything else I need to pay attention to, please feel free to communicate anytime, thank you!

…en` parameter, enabling long-text input without the need to set the environment variable `VLLM_ALLOW_LONG_MAX_MODEL_LEN`. Modify the relevant configurations and processing logic to ensure clear error messages are provided when the input exceeds the maximum embedding length, while maintaining backward compatibility. Enhance the description of input validation and processing performance.

Signed-off-by: x22x22 <wadeking@qq.com>
@x22x22
Author

x22x22 commented Jul 13, 2025

@noooop
I've addressed both concerns:

  1. I've removed the dependency on VLLM_ALLOW_LONG_MAX_MODEL_LEN
  2. Instead of modifying max_model_len, we now configure max_embed_len through the --override-pooler-config parameter as follows:
{
  "pooling_type": "CLS",
  "normalize": true,
  "enable_chunked_processing": true,
  "max_embed_len": 10240
}
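
For clarity, the intended relationship between `max_embed_len` and `max_model_len`, as I read it from the commits above, can be sketched as follows. The function and messages are illustrative only, not the PR's actual serving code:

def plan_embedding_request(num_tokens: int, max_model_len: int,
                           max_embed_len: int, chunked_processing: bool) -> str:
    """Illustrative request handling: chunk between max_model_len and max_embed_len, reject beyond."""
    if num_tokens <= max_model_len:
        return "process normally as a single sequence"
    if not chunked_processing:
        return "reject: input exceeds max_model_len and chunked processing is disabled"
    if num_tokens > max_embed_len:
        return "reject with a clear error: input exceeds max_embed_len"
    return "split into chunks of at most max_model_len tokens and aggregate the embeddings"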

… ensuring the correctness of the configuration when dealing with long-text inputs. Adjust the format of the relevant configuration strings to better handle the embedding length limit.

Signed-off-by: x22x22 <wadeking@qq.com>
@maxdebayser
Contributor

@x22x22, please correct me if I'm wrong, but it seems that the aggregation is based on the assumption that taking the mean of the embedding chunks would be correct:

aggregated_embedding = weighted_sum / total_weight

For this to work, two requirements are necessary:

  1. That the pooling type is MEAN (and not CLS, LAST or others)
  2. That the model uses a causal attention mask so that the tokens can only attend to previous tokens.

However, the BERT-type models don't satisfy the second requirement.
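
To make the requirement concrete (notation mine, not from the PR): if chunk $i$ contains $n_i$ tokens with hidden states $h_{i,1},\dots,h_{i,n_i}$ and mean-pooled embedding $e_i = \tfrac{1}{n_i}\sum_j h_{i,j}$, then the token-count weighted average used here is

$$
\frac{\sum_i n_i e_i}{\sum_i n_i} \;=\; \frac{1}{\sum_i n_i}\sum_i \sum_{j=1}^{n_i} h_{i,j},
$$

which recovers full-sequence mean pooling only if every $h_{i,j}$ is the same as it would have been in the full sequence. When chunks are encoded independently by a bidirectional encoder, tokens cannot attend across chunk boundaries, so the hidden states differ and the identity holds only approximately.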

As @noooop mentioned, there are newer models that are decoder models. For these models, if the pooling type is LAST, we already support chunked prefill.

I think we should remove VLLM_ALLOW_LONG_MAX_MODEL_LEN. I can’t think of any use case that would require this flag.

@noooop , this var is useful for testing or bypassing restrictions of misconfigured models.

…ced chunk processing functionality. The logic for automatic detection and verification of pooling types has been optimized to ensure warnings are provided when non-MEAN pooling types are used. The relevant configurations and processing logic have been updated to improve user experience and compatibility.

Signed-off-by: x22x22 <wadeking@qq.com>
@x22x22
Author

x22x22 commented Jul 14, 2025

@x22x22 , please correct me if I'm wrong, but it seems that the aggregation is based on the assumptions that taking the mean of the embedding chunks would be correct:

aggregated_embedding = weighted_sum / total_weight

To work two requirements are necessary:

  1. That the pooling type is MEAN (and not CLS, LAST or others)
  2. That the model uses a causal attention mask so that the tokens can only attend to previous tokens.

However, the BERT-type models don't satisfy the second requirement.
@maxdebayser

Thank you for your feedback - you're absolutely right! I've updated the implementation to use MEAN pooling and made the following improvements:

  • Correct pooling configuration: multilingual-e5-large now uses MEAN pooling by default
  • Automatic detection: Added support for automatic configuration of various popular models
  • Manual specification: Users can manually specify the pooling type with helpful prompts and guidance
  • Safety warnings: Users are now alerted about potential impacts when using non-MEAN pooling methods
  • Flexible configuration: All parameters can be customized through environment variables

This ensures users can safely utilize the chunked processing functionality without worrying about pooling type mismatches!

For usage reference, please check out the new example startup script:
https://github.com/vllm-project/vllm/blob/a5432ac40c23dcbeba8ce3bb6af4084591dd0f47/examples/online_serving/openai_embedding_long_text_service.sh

Your point about BERT-type models not satisfying the causal attention requirement is particularly important. The weighted aggregation approach works best with models that use mean pooling and have the appropriate attention patterns. The automatic detection and safety warnings should help users avoid potential issues with incompatible model architectures.
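
Purely as an illustration of the detection-and-warning behavior described above (the helper name and message are hypothetical, not the PR's code):

import logging

logger = logging.getLogger(__name__)

def warn_if_pooling_incompatible(pooling_type: str) -> None:
    """Emit a warning when chunked processing is paired with a non-MEAN pooling type."""
    if pooling_type.upper() != "MEAN":
        logger.warning(
            "Chunked processing assumes MEAN pooling; pooling_type=%s may degrade "
            "aggregated embeddings for inputs that span multiple chunks.", pooling_type)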

@maxdebayser
Contributor

Yes, but intfloat/multilingual-e5-large is a RoBERTa model, so the attention is not causal, which means the attention has to be computed on the full sequence at once. Otherwise you're not going to get correct results.

@x22x22
Author

x22x22 commented Jul 14, 2025

In fact, embedding models are not very suitable for handling extremely long inputs, as too much content can lead to embeddings that are not able to effectively distinguish between similar content.

Here's a simple way to confirm that automatic chunked processing is working effectively:

Reference mteb_test_embed_models in vllm/tests/models/language/pooling/mteb_utils.py and https://github.com/noooop/snippet/blob/main/benchmarks/test_mteb/test_speed.py

Keeping only the very front part of long context, such as 2048 or even 512, is an extremely high baseline. Refer to LongEmbed: Extending Embedding Models for Long Context Retrieval

However, it still suffers from biased distribution of key information, as demonstrated in Figure 2. With only 512 context length, E5Base achieves >85% nDCG scores on 3 out of 8 publicly available LoCo tasks.

Do the following three comparative experiments

  • max_model_len = 2048
  • max_model_len =8102
  • max_model_len = 2048 + automatic chunked processing

If automatic chunked processing using multilingual-e5-large on mteb/T2Reranking dataset(or any test with a context exceeding 8K), can achieve comparable results indicates that automatic chunked processing is effective

@noooop

Based on your suggestion, I adapted https://github.com/dwzhu-pku/LongEmbed to the OpenAI SDK interface and obtained the following benchmark results:

Model: multilingual-e5-large

Branch: feat/support-long-text-embedding

Parameter Set 1

  • max_input_tokens: 3072000
  • window_length_list: 256 512 1024 2048 4096 8192
{
  "LEMBNeedleRetrieval": {
    "256": 0.74,
    "512": 0.64,
    "1024": 0.68,
    "2048": 0.52,
    "4096": 0.5,
    "8192": 0.26,
    "16384": 0.34,
    "32768": 0.3
  },
  "LEMBPasskeyRetrieval": {
    "256": 1.0,
    "512": 1.0,
    "1024": 0.42,
    "2048": 0.64,
    "4096": 0.14,
    "8192": 0.34,
    "16384": 0.1,
    "32768": 0.14
  }
}

Parameter Set 2

  • max_input_tokens: 450
  • window_length_list: 256 512
{
  "LEMBNeedleRetrieval": {
    "256": 0.74,
    "512": 0.72
  },
  "LEMBPasskeyRetrieval": {
    "256": 1.0,
    "512": 0.94
  }
}

Branch: main

  • max_input_tokens: 450
  • window_length_list: 256 512
{
  "LEMBNeedleRetrieval": {
    "256": 0.74,
    "512": 0.72
  },
  "LEMBPasskeyRetrieval": {
    "256": 1.0,
    "512": 0.94
  }
}

Model: jina-embeddings-v3

Branch: feat/support-long-text-embedding

  • max_input_tokens: 3072000
  • window_length_list: 256 512 1024 2048 4096 8192
{
  "LEMBNeedleRetrieval": {
    "256": 0.86,
    "512": 0.66,
    "1024": 0.26,
    "2048": 0.18,
    "4096": 0.08,
    "8192": 0.18,
    "16384": 0.18,
    "32768": 0.26
  },
  "LEMBPasskeyRetrieval": {
    "256": 1.0,
    "512": 1.0,
    "1024": 0.92,
    "2048": 0.92,
    "4096": 0.36,
    "8192": 0.4,
    "16384": 0.42,
    "32768": 0.36
  }
}

Branch: main

  • max_input_tokens: 8000
  • window_length_list: 256 512 1024 2048 4096 8192
{
  "LEMBNeedleRetrieval": {
    "256": 0.86,
    "512": 0.66,
    "1024": 0.26,
    "2048": 0.18,
    "4096": 0.08,
    "8192": 0.14
  },
  "LEMBPasskeyRetrieval": {
    "256": 1.0,
    "512": 1.0,
    "1024": 0.92,
    "2048": 0.92,
    "4096": 0.36,
    "8192": 0.4
  }
}

From the above data, we can conclude that the benchmark scores for the unmodified main branch and the modified feat/support-long-text-embedding branch are consistent within the allowed length range of the embedding models. For contexts exceeding the model's native capacity, while there is some performance degradation, as long as the excess isn't too significant, the models can still maintain reasonable performance levels.

Your experimental results provide valuable insights into the effectiveness of automatic chunked processing for long context embeddings. The comparison between the two branches demonstrates that the chunking approach maintains comparable performance within the model's native token limits while extending functionality to handle longer inputs. This aligns with the findings from the LongEmbed paper you referenced, where even truncated contexts can achieve surprisingly high performance on many retrieval tasks.

The results also highlight the model-specific nature of long context handling - multilingual-e5-large and jina-embeddings-v3 show different degradation patterns as context length increases, which is important for practitioners choosing between models for specific use cases.

@x22x22 x22x22 closed this Jul 14, 2025
@x22x22 x22x22 reopened this Jul 14, 2025
@x22x22
Author

x22x22 commented Jul 14, 2025

In fact, embedding models are not very suitable for handling extremely long inputs, as too much content can lead to embeddings that are not able to effectively distinguish between similar content.
Here's a simple way to confirm that automatic chunked processing is working effectively:
Reference mteb_test_embed_models in vllm/tests/models/language/pooling/mteb_utils.py and https://github.com/noooop/snippet/blob/main/benchmarks/test_mteb/test_speed.py
Keeping only the very front part of long context, such as 2048 or even 512, is an extremely high baseline. Refer to LongEmbed: Extending Embedding Models for Long Context Retrieval

However, it still suffers from biased distribution of key information, as demonstrated in Figure 2. With only 512 context length, E5Base achieves >85% nDCG scores on 3 out of 8 publicly available LoCo tasks.

Do the following three comparative experiments

  • max_model_len = 2048
  • max_model_len = 8192
  • max_model_len = 2048 + automatic chunked processing

If automatic chunked processing with multilingual-e5-large on the mteb/T2Reranking dataset (or any test with contexts exceeding 8K) can achieve comparable results, that indicates automatic chunked processing is effective.

@noooop

Based on your suggestion, I adapted https://github.com/dwzhu-pku/LongEmbed to the OpenAI SDK interface and obtained the following benchmark results:

Model: multilingual-e5-large

Branch: feat/support-long-text-embedding

Parameter Set 1

  • max_input_tokens: 3072000
  • window_length_list: 256 512 1024 2048 4096 8192 16384 32768
{
  "LEMBNeedleRetrieval": {
    "256": 0.74,
    "512": 0.64,
    "1024": 0.68,
    "2048": 0.52,
    "4096": 0.5,
    "8192": 0.26,
    "16384": 0.34,
    "32768": 0.3
  },
  "LEMBPasskeyRetrieval": {
    "256": 1.0,
    "512": 1.0,
    "1024": 0.42,
    "2048": 0.64,
    "4096": 0.14,
    "8192": 0.34,
    "16384": 0.1,
    "32768": 0.14
  }
}

Parameter Set 2

  • max_input_tokens: 450
  • window_length_list: 256 512
{
  "LEMBNeedleRetrieval": {
    "256": 0.74,
    "512": 0.72
  },
  "LEMBPasskeyRetrieval": {
    "256": 1.0,
    "512": 0.94
  }
}

Branch: main

  • max_input_tokens: 450
  • window_length_list: 256 512
{
  "LEMBNeedleRetrieval": {
    "256": 0.74,
    "512": 0.72
  },
  "LEMBPasskeyRetrieval": {
    "256": 1.0,
    "512": 0.94
  }
}

Model: jina-embeddings-v3

Branch: feat/support-long-text-embedding

  • max_input_tokens: 3072000
  • window_length_list: 256 512 1024 2048 4096 8192 16384 32768
{
  "LEMBNeedleRetrieval": {
    "256": 0.86,
    "512": 0.66,
    "1024": 0.26,
    "2048": 0.18,
    "4096": 0.08,
    "8192": 0.18,
    "16384": 0.18,
    "32768": 0.26
  },
  "LEMBPasskeyRetrieval": {
    "256": 1.0,
    "512": 1.0,
    "1024": 0.92,
    "2048": 0.92,
    "4096": 0.36,
    "8192": 0.4,
    "16384": 0.42,
    "32768": 0.36
  }
}

Branch: main

  • max_input_tokens: 8000
  • window_length_list: 256 512 1024 2048 4096 8192
{
  "LEMBNeedleRetrieval": {
    "256": 0.86,
    "512": 0.66,
    "1024": 0.26,
    "2048": 0.18,
    "4096": 0.08,
    "8192": 0.14
  },
  "LEMBPasskeyRetrieval": {
    "256": 1.0,
    "512": 1.0,
    "1024": 0.92,
    "2048": 0.92,
    "4096": 0.36,
    "8192": 0.4
  }
}

From the above data, we can conclude that the benchmark scores for the unmodified main branch and the modified feat/support-long-text-embedding branch are consistent within the allowed length range of the embedding models. For contexts exceeding the model's native capacity, while there is some performance degradation, as long as the excess isn't too significant, the models can still maintain reasonable performance levels.

Your experimental results provide valuable insights into the effectiveness of automatic chunked processing for long context embeddings. The comparison between the two branches demonstrates that the chunking approach maintains comparable performance within the model's native token limits while extending functionality to handle longer inputs. This aligns with the findings from the LongEmbed paper you referenced, where even truncated contexts can achieve surprisingly high performance on many retrieval tasks.

The results also highlight the model-specific nature of long context handling - multilingual-e5-large and jina-embeddings-v3 show different degradation patterns as context length increases, which is important for practitioners choosing between models for specific use cases.

Yes, but intfloat/multilingual-e5-large is a RoBERTa model, so the attention is not causal, which means that attention has to be computed over the full sequence at once. Otherwise you're not going to get correct results.

@maxdebayser

You can take a look at my benchmark results - from the scores, MEAN pooling aggregation appears to be effective for multilingual-e5-large.

While you're correct that multilingual-e5-large uses a RoBERTa architecture with bidirectional rather than causal attention, the benchmark data suggests that the chunking approach still provides meaningful results. The performance degradation with longer contexts is gradual rather than catastrophic, and the model maintains reasonable effectiveness when processing contexts that exceed its native capacity by 1-2x.

For example, multilingual-e5-large has an original max length of 512, but we can observe that when extended to 2048 tokens, LEMBNeedleRetrieval and LEMBPasskeyRetrieval still achieve scores of 0.5-0.6.

For jina-embeddings-v3, which has an original max length of 8192, with the extended length support (feat/support-long-text-embedding branch), the performance difference between 32768 and 8192 token lengths is minimal. Conversely, because the main branch doesn't support extended lengths, max_input_tokens has to be capped with some reserved headroom (8000), resulting in worse performance than the feat/support-long-text-embedding branch.

This demonstrates that the chunking approach provides a practical extension beyond the model's native limits, though with diminishing returns as the context grows significantly longer than the original capacity. The trade-off between perfect attention computation and extended context handling seems to be worthwhile in these specific use cases.
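To make the aggregation being discussed concrete, here is a minimal sketch of FastChat-style token-count-weighted averaging over per-chunk MEAN-pooled embeddings. The function name and array shapes are illustrative assumptions, not the PR's actual code.

import numpy as np

def aggregate_chunk_embeddings(chunk_embeddings, chunk_token_counts):
    # chunk_embeddings: (num_chunks, dim) MEAN-pooled vectors, one per chunk
    # chunk_token_counts: number of tokens that produced each chunk embedding
    embeddings = np.asarray(chunk_embeddings, dtype=np.float32)
    weights = np.asarray(chunk_token_counts, dtype=np.float32)
    combined = (embeddings * weights[:, None]).sum(axis=0) / weights.sum()
    # Normalize so downstream cosine-similarity search behaves as usual.
    return combined / np.linalg.norm(combined)

Because the weights are raw token counts, longer chunks dominate the average, which matches the intuition that they contribute more of the document's content.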

Copy link
Contributor

@maxdebayser maxdebayser left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The LongEmbed paper calls what you're trying to do PCW I think: https://arxiv.org/pdf/2404.12096. In the paper they show that this context extension technique on absolute position embedding models is the one that has the worst performance of all the ones they have considered.

My question is, given that there are decoder embedding models like Alibaba-NLP/gte-Qwen2-1.5B-instruct that support 32k of context, and encoder embedding models with RoPE, such as nomic-ai/nomic-embed-text-v1, that support 8K and can more easily be tuned to longer context lengths, why would you want to extend older BERT models to such long context lengths?

@x22x22
Copy link
Author

x22x22 commented Jul 15, 2025

The LongEmbed paper calls what you're trying to do PCW I think: https://arxiv.org/pdf/2404.12096. In the paper they show that this context extension technique on absolute position embedding models is the one that has the worst performance of all the ones they have considered.

My question is, given that there are decoder embedding models like Alibaba-NLP/gte-Qwen2-1.5B-instruct that support 32k of context, and encoder embedding models with RoPE, such as nomic-ai/nomic-embed-text-v1, that support 8K and can more easily be tuned to longer context lengths, why would you want to extend older BERT models to such long context lengths?

@maxdebayser

There are some of our applications that are still using older embedding models, and they don't have plans to replace their current embedding models.

Previously, these applications were running their embedding models with FastChat, which supported the same approach for extending context length, but its architectural limitations led to poor performance. So we migrated to vLLM, only to discover that vLLM doesn't support this context extension capability.

Additionally, even for models like Alibaba-NLP/gte-Qwen2-1.5B-instruct that support 32k of context, this approach can still be used to extend beyond the 32k limit and break through that constraint.

In our real-world scenario testing, this approach combined with higher top_k retrieval plus reranking can actually achieve better overall results than the evaluation scores might suggest. At the same time, it keeps the service robust when business workloads occasionally encounter unexpectedly long content.

While I understand that the LongEmbed paper shows PCW (Parallel Context Windows) performs poorly on absolute position embedding models compared to other techniques, the business constraint here is maintaining compatibility with existing model deployments rather than achieving optimal performance. Sometimes legacy system requirements take precedence over using the most technically superior approach, and practical performance in production can differ from benchmark evaluations.

x22x22 added 2 commits July 15, 2025 11:10
- Process only relevant chunks (last for LAST, first for CLS pooling)
- Disable chunked processing by default for these types due to semantic issues
- Remove unused AVG pooling type references
- Add explicit user override option with warnings

Fixes computational waste identified in code review.

Signed-off-by: x22x22 <wadeking@qq.com>
Replace batch aggregation with streaming aggregation to prevent memory
spikes and potential DoS attacks. Process chunk results incrementally
instead of accumulating complete chunk lists in memory, ensuring
near-constant memory usage regardless of input length.

Signed-off-by: x22x22 <wadeking@qq.com>
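As a rough illustration of the streaming aggregation described in the commit message above (hypothetical class name, not the code in this PR), the same weighted mean can be kept as a running sum, so memory stays near-constant no matter how many chunks a request produces:

import numpy as np

class StreamingWeightedMean:
    def __init__(self, dim: int):
        self.acc = np.zeros(dim, dtype=np.float32)  # running weighted sum
        self.total_weight = 0.0

    def update(self, chunk_embedding, chunk_token_count: int):
        # Fold each chunk result in immediately; nothing is retained per chunk.
        self.acc += np.asarray(chunk_embedding, dtype=np.float32) * chunk_token_count
        self.total_weight += chunk_token_count

    def result(self):
        mean = self.acc / self.total_weight
        return mean / np.linalg.norm(mean)  # normalize for cosine similarity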
@noooop
Copy link
Contributor

noooop commented Jul 15, 2025

From the above data, we can conclude that the benchmark scores for the unmodified main branch and the modified feat/support-long-text-embedding branch are consistent within the allowed length range of the embedding models. For contexts exceeding the model's native capacity, while there is some performance degradation, as long as the excess isn't too significant, the models can still maintain reasonable performance levels.

Your experimental results provide valuable insights into the effectiveness of automatic chunked processing for long context embeddings. The comparison between the two branches demonstrates that the chunking approach maintains comparable performance within the model's native token limits while extending functionality to handle longer inputs. This aligns with the findings from the LongEmbed paper you referenced, where even truncated contexts can achieve surprisingly high performance on many retrieval tasks.

The results also highlight the model-specific nature of long context handling - multilingual-e5-large and jina-embeddings-v3 show different degradation patterns as context length increases, which is important for practitioners choosing between models for specific use cases.

Awesome to hear this result!

Please reorganize the introduction to highlight this result; perhaps adding two images would make it more intuitive here.

  • For models with short context lengths, such as multilingual-e5-large which only supports a context length of 512, using the automatic chunked processing method can effectively extend the context to 1K or even longer.
  • For models that support long context, such as jina-embeddings-v3, using a small max_model_len (e.g., 1k) + automatic chunked processing achieves results similar to a larger max_model_len (e.g., 2k), but faster (this still needs data to support it).

(Personally, I think the context range in which embedding models work well is less than 2k. If it exceeds 2k, you may need a very large model, such as Qwen3-Embedding-8B, but even then you may not get good results, because the embedding dimension limits what a single vector can represent. Using the ColBERT architecture or a cross-encoder architecture is more reasonable.)

I think this method is similar to the way jina-reranker-v2 extends the context.

The rerank() function will automatically chunk the input documents into smaller pieces if they exceed the model's maximum input length. This allows you to rerank long documents without running into memory issues. Specifically, the rerank() function will split the documents into chunks of size max_length and rerank each chunk separately. The scores from all the chunks are then combined to produce the final reranking results. You can control the query length and document length in each chunk by setting the max_query_length and max_length parameters. The rerank() function also supports the overlap parameter (default is 80) which determines how much overlap there is between adjacent chunks. This can be useful when reranking long documents to ensure that the model has enough context to make accurate predictions.
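For intuition only, here is a sketch of the overlap-style chunking the quoted description refers to, assuming the input is already tokenized; the max_length and overlap parameters mirror the quote, but this is not Jina's implementation:

def split_into_chunks(token_ids, max_length: int = 512, overlap: int = 80):
    # Slide a window of max_length tokens, stepping by max_length - overlap,
    # so adjacent chunks share `overlap` tokens of context.
    assert max_length > overlap
    step = max_length - overlap
    chunks = []
    for start in range(0, len(token_ids), step):
        chunks.append(token_ids[start:start + max_length])
        if start + max_length >= len(token_ids):
            break
    return chunks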

I think the part about VLLM_ALLOW_LONG_MAX_MODEL_LEN causing a CUDA error should be removed. It takes up too much space and is not relevant to this PR.

@x22x22
Copy link
Author

x22x22 commented Jul 15, 2025

The LongEmbed paper calls what you're trying to do PCW I think: https://arxiv.org/pdf/2404.12096. In the paper they show that this context extension technique on absolute position embedding models is the one that has the worst performance of all the ones they have considered.

My question is, given that there are decoder embedding models like Alibaba-NLP/gte-Qwen2-1.5B-instruct that support 32k of context, and encoder embedding models with RoPE, such as nomic-ai/nomic-embed-text-v1, that support 8K and can more easily be tuned to longer context lengths, why would you want to extend older BERT models to such long context lengths?

@maxdebayser

Here are the LongEmbed test results for Alibaba-NLP/gte-Qwen2-1.5B-instruct (since gte-Qwen2-1.5B-instruct natively supports up to 32k, the automatic chunking feature was not triggered):

{
  "LEMBNeedleRetrieval": {
    "256": 0.44,
    "512": 0.44,
    "1024": 0.3,
    "2048": 0.5,
    "4096": 0.54,
    "8192": 0.32,
    "16384": 0.4,
    "32768": 0.26
  },
  "LEMBPasskeyRetrieval": {
    "256": 0.34,
    "512": 0.26,
    "1024": 0.38,
    "2048": 0.16,
    "4096": 0.22,
    "8192": 0.12,
    "16384": 0.22,
    "32768": 0.04
  }
}

Despite its native support for ultra-long context, the performance is actually worse than multilingual-e5-large.

Here are the multilingual-e5-large results:

{
  "LEMBNeedleRetrieval": {
    "256": 0.74,
    "512": 0.64,
    "1024": 0.68,
    "2048": 0.52,
    "4096": 0.5,
    "8192": 0.26,
    "16384": 0.34,
    "32768": 0.3
  },
  "LEMBPasskeyRetrieval": {
    "256": 1.0,
    "512": 1.0,
    "1024": 0.42,
    "2048": 0.64,
    "4096": 0.14,
    "8192": 0.34,
    "16384": 0.1,
    "32768": 0.14
  }
}

This performance comparison perfectly illustrates why many production applications that have been running for some time don't plan to switch to newer embedding models. When you've already achieved sufficiently high RAG recall rates through enhanced RAG pipelines combined with multilingual-e5-large in your specific scenarios, there's little willingness to risk trying new embedding models. I believe many other enterprises face this same situation.

Signed-off-by: x22x22 <wadeking@qq.com>
@x22x22
Copy link
Author

x22x22 commented Jul 15, 2025

From the above data, we can conclude that the benchmark scores for the unmodified main branch and the modified feat/support-long-text-embedding branch are consistent within the allowed length range of the embedding models. For contexts exceeding the model's native capacity, while there is some performance degradation, as long as the excess isn't too significant, the models can still maintain reasonable performance levels.
Your experimental results provide valuable insights into the effectiveness of automatic chunked processing for long context embeddings. The comparison between the two branches demonstrates that the chunking approach maintains comparable performance within the model's native token limits while extending functionality to handle longer inputs. This aligns with the findings from the LongEmbed paper you referenced, where even truncated contexts can achieve surprisingly high performance on many retrieval tasks.
The results also highlight the model-specific nature of long context handling - multilingual-e5-large and jina-embeddings-v3 show different degradation patterns as context length increases, which is important for practitioners choosing between models for specific use cases.

Awesome to hear this result!

Please reorganize the introduction to highlight this result, perhaps add two images may be more intuitive here.

@noooop @maxdebayser

[Image: PixPin_2025-07-15_21-30-18]

@noooop
Copy link
Contributor

noooop commented Jul 15, 2025

@x22x22
Copy link
Author

x22x22 commented Jul 15, 2025

@x22x22

nice introduction

double check the implementation of Alibaba-NLP/gte-Qwen2-1.5B-instruct

Refer to

@noooop @maxdebayser
Thanks for the suggestion! I've double-checked the implementation of Alibaba-NLP/gte-Qwen2-1.5B-instruct based on the references you provided:

  1. Updated to the latest gte-Qwen2-1.5B-instruct model (fix_issue:is_causal)
  2. Set VLLM_ATTENTION_BACKEND=XFORMERS

The new benchmark results are significantly improved:

{
    "LEMBNeedleRetrieval": {
        "256": 0.86,
        "512": 0.92,
        "1024": 0.76,
        "2048": 0.62,
        "4096": 0.64,
        "8192": 0.42,
        "16384": 0.18,
        "32768": 0.08
    },
    "LEMBPasskeyRetrieval": {
        "256": 1.0,
        "512": 0.98,
        "1024": 1.0,
        "2048": 1.0,
        "4096": 1.0,
        "8192": 0.8,
        "16384": 0.36,
        "32768": 0.16
    }
}
  1. I've updated the charts with the corrected data.

  2. Interestingly, multilingual-e5-large still shows slight advantages at 16384 and 32768 token lengths in certain tasks, suggesting chunked processing may have benefits for extremely long contexts even compared to native long-context models.

This shows much better performance, especially in the shorter context ranges where it now matches or exceeds multilingual-e5-large. The implementation fix was indeed crucial for getting accurate results. Thank you for pointing out those implementation details!

@maxdebayser
Copy link
Contributor

@x22x22 , thanks for running the benchmarks and providing nice graphs, I think they are a convincing argument in favor of context extension. But I'm still not sure that out of the context extension methods, this is the best one. I think we should try the GP, RP and PI methods of the LongEmbeds paper with e5-large. Because they manipulate the position ids, these methods will work with all pooling types and can run entirely on the GPU. In the paper they also show superior performance. Do you have the time to try them? It shouldn't be too hard to run a proof of concept, I think the only required changes are a modification of the position_ids in bert.py or roberta.py, setting --max-model-len 32768 and VLLM_ALLOW_LONG_MAX_MODEL_LEN=1.

Otherwise, if you're willing to share your benchmarking script I can try to run this experiment.
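To show roughly what is being proposed, here is a hedged sketch of the PI (position interpolation) idea for an absolute-position-embedding model: squeeze the positions of a long input back into the trained range so the learned position embeddings can still be indexed. This is an illustration under assumptions, not a patch to vLLM's bert.py or roberta.py.

import torch

def interpolated_position_ids(seq_len: int, trained_max_len: int = 512) -> torch.Tensor:
    positions = torch.arange(seq_len, dtype=torch.float32)
    if seq_len > trained_max_len:
        # Rescale [0, seq_len) onto [0, trained_max_len), then round, since
        # learned absolute position embeddings require integer indices.
        positions = positions * (trained_max_len - 1) / (seq_len - 1)
    return positions.round().long()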

@x22x22
Copy link
Author

x22x22 commented Jul 15, 2025

@x22x22 , thanks for running the benchmarks and providing nice graphs, I think they are a convincing argument in favor of context extension. But I'm still not sure that out of the context extension methods, this is the best one. I think we should try the GP, RP and PI methods of the LongEmbeds paper with e5-large. Because they manipulate the position ids, these methods will work with all pooling types and can run entirely on the GPU. In the paper they also show superior performance. Do you have the time to try them? It shouldn't be too hard to run a proof of concept, I think the only required changes are a modification of the position_ids in bert.py or roberta.py, setting --max-model-len 32768 and VLLM_ALLOW_LONG_MAX_MODEL_LEN=1.

Otherwise, if you're willing to share your benchmarking script I can try to run this experiment.

@maxdebayser

I've committed my modifications to https://github.com/x22x22/LongEmbed/tree/feature/add-openai-embedding-support, please pull the branch.

# You can skip installing flash-attn and other dependencies, as we mainly need mteb, openai>=1.0.0, and tiktoken
pip install -r requirements.txt

# Modify the BASE_URL and API_KEY in scripts/run_openai_long_embed.sh
/bin/bash scripts/run_openai_long_embed.sh

After evaluation is complete, the results will be output to ./results.
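For anyone reproducing this without the script, the adapted harness essentially sends standard OpenAI-compatible embedding requests to the vLLM server. A minimal sketch, where the base URL, API key, and model name are placeholders for your own deployment:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.embeddings.create(
    model="intfloat/multilingual-e5-large",
    input="a very long document that may exceed the model's native context ...",
)
print(len(resp.data[0].embedding))  # embedding dimension, e.g. 1024 for e5-large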

I submitted this PR hoping to find a universal context extension approach for embedding models without intrusive modifications to the model source code - that's the advantage of this method. Of course, I understand this doesn't conflict with the optimization methods mentioned in the LongEmbeds paper. We could also create another PR that targets different models by modifying their source code to extend embedding model context.

@maxdebayser
Copy link
Contributor

@x22x22 , I implemented the RP, GP, and PI methods from the paper but couldn't get good results. For e5-multilingual-large, this is the typical output I'm getting with these context extension methods:

  "LEMBNeedleRetrieval": {
    "256": 0.74,
    "512": 0.72,
    "1024": 0.82,
    "2048": 0.4,
    "4096": 0.22,
    "8192": 0.02,
    "16384": 0.02,
    "32768": 0.0,
    "avg": 0.3675
  }

Up to 1024 and 2048 they seem to work and then the performance just drops. On the other hand I was able to reproduce your results with your branch, which is always good.

Can you also run the benchmark for models with CLS or LAST pooling?

Labels: documentation, frontend
Projects: None yet
4 participants