
Commit 5398bbd

Add a chunked processing function that supports long-text embedding, and update the relevant documentation and examples. New example scripts and service startup scripts demonstrate how to configure and use chunked processing. Update the model configuration to support long-text processing and implement the chunked processing logic in the code.
Signed-off-by: x22x22 <wadeking@qq.com>
1 parent 53fa457 commit 5398bbd

File tree

7 files changed (+966, -4 lines)


docs/models/pooling_models.md

Lines changed: 85 additions & 1 deletion
@@ -32,6 +32,90 @@ we attempt to override the default pooler based on its Sentence Transformers con
You can customize the model's pooling method via the `--override-pooler-config` option,
which takes priority over both the model's and Sentence Transformers's defaults.

## Chunked Processing for Long Text

vLLM supports **chunked processing** for embedding models to handle text inputs that exceed the model's maximum token length. This feature automatically splits long text into manageable chunks, processes them separately, and aggregates the results.

### Supported Models

- `intfloat/multilingual-e5-large`
- Other embedding models can be extended to support this feature

### How Chunked Processing Works

1. **Automatic Detection**: When input text exceeds `max_model_len`, chunked processing is triggered
2. **Smart Chunking**: Text is split at token boundaries to maintain semantic integrity (see the sketch below)
3. **Parallel Processing**: Each chunk is processed independently through the model
4. **Intelligent Aggregation**: Results are combined using weighted averaging based on chunk token counts
5. **Consistent Output**: Final embeddings maintain the same dimensionality as standard processing
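
To make steps 1–2 concrete, here is a minimal sketch of token-boundary chunking, assuming the model's Hugging Face tokenizer. It illustrates the idea only and is not vLLM's actual implementation; `split_into_chunks` and the standalone `max_model_len` variable are illustrative names.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-large")
max_model_len = 10240  # should match the value passed to --max-model-len


def split_into_chunks(text: str) -> list[list[int]]:
    """Split the tokenized input into consecutive chunks of at most max_model_len tokens."""
    token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    if len(token_ids) <= max_model_len:
        return [token_ids]  # short input: chunked processing is not triggered
    return [token_ids[i:i + max_model_len]
            for i in range(0, len(token_ids), max_model_len)]
```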

### Configuration

Enable chunked processing by setting `enable_chunked_processing: true` in the pooler configuration:

```bash
vllm serve intfloat/multilingual-e5-large \
  --task embed \
  --override-pooler-config '{"pooling_type": "CLS", "normalize": true, "enable_chunked_processing": true}' \
  --max-model-len 10240 \
  --trust-remote-code
```

### Aggregation Algorithm

The chunked processing uses a FastChat-inspired weighted averaging algorithm:

```python
# Weighted average: sum(embedding_i * token_count_i) / total_tokens
# weights[i] is the token count of chunk i; embeddings[i] is its pooled vector
weighted_sum = sum(embeddings[i] * weights[i] for i in range(num_chunks))
final_embedding = weighted_sum / sum(weights)
```

This ensures that longer chunks contribute proportionally more to the final representation.
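
For example, an 8,000-token chunk carries four times the weight of a 2,000-token chunk. A minimal numeric sketch (the embedding values below are invented purely for illustration):

```python
import numpy as np

# Hypothetical pooled embeddings for two chunks, kept 2-dimensional for readability
chunk_embeddings = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
token_counts = [8000, 2000]  # the weights are the per-chunk token counts

weighted_sum = sum(emb * count for emb, count in zip(chunk_embeddings, token_counts))
final_embedding = weighted_sum / sum(token_counts)
print(final_embedding)  # [0.8 0.2] -- the longer chunk dominates
```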

### Performance Characteristics

| Aspect | Short Text (≤ max_len) | Long Text (> max_len) |
|--------|------------------------|-----------------------|
| **Processing Time** | Standard | Increased (multiple inference calls) |
| **Memory Usage** | Standard | Reduced (chunks processed separately) |
| **Quality** | Standard | Maintains semantic representation |
| **Compatibility** | Full | Full (backward compatible) |

### Example Usage

```python
from openai import OpenAI

client = OpenAI(
    api_key="your-api-key",
    base_url="http://localhost:31090/v1"
)

# This will automatically use chunked processing if text is too long
response = client.embeddings.create(
    input="Very long text that exceeds the model's maximum context length..." * 1000,
    model="multilingual-e5-large"
)

print(f"Embedding dimension: {len(response.data[0].embedding)}")
```

### Logging and Monitoring

When chunked processing is active, you'll see informative log messages:

```
INFO: Input length 15000 exceeds max_model_len 10240, will use chunked processing
INFO: Split input of 15000 tokens into 2 chunks
```

### Limitations

- **Increased Latency**: Processing multiple chunks takes longer than single-chunk processing
- **Model Support**: Currently limited to specific embedding models
- **Context Boundaries**: Chunking may split related content, though weighted averaging helps preserve overall semantics

## Offline Inference

The [LLM][vllm.LLM] class provides various methods for offline inference.
@@ -170,7 +254,7 @@ vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the `dimensions` parameter.

```text
-curl http://127.0.0.1:8000/v1/embeddings \
+curl http://127.0.0.1:31090/v1/embeddings \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
```

docs/models/supported_models.md

Lines changed: 4 additions & 1 deletion
@@ -418,7 +418,7 @@ Specified using `--task embed`.
| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | |
| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | |
| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | |
-| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, `intfloat/multilingual-e5-large` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | |
@@ -437,6 +437,9 @@ Specified using `--task embed`.
!!! note
    The second-generation GTE model (mGTE-TRM) is named `NewModel`. The name `NewModel` is too generic, so you should set `--hf-overrides '{"architectures": ["GteNewModel"]}'` to specify the use of the `GteNewModel` architecture.

!!! note
    `intfloat/multilingual-e5-large` supports **long text embedding** with chunked processing. When input text exceeds the model's maximum length, the model automatically splits the input into chunks, processes them separately, and then aggregates the results. Enable this feature with `--override-pooler-config '{"pooling_type": "CLS", "normalize": true, "enable_chunked_processing": true}'`. See the [Chunked Processing section](pooling_models.md#chunked-processing-for-long-text) for more details.

If your model is not in the above list, we will try to automatically convert the model using
[as_embedding_model][vllm.model_executor.models.adapters.as_embedding_model]. By default, the embeddings
of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@

# Long Text Embedding with Chunked Processing

This directory contains examples for using vLLM's **chunked processing** feature to handle long text embedding that exceeds the model's maximum context length.

## 🚀 Quick Start

### 1. Start the Server

Use the provided script to start a vLLM server with chunked processing enabled:

```bash
# Basic usage
./openai_embedding_long_text_service.sh

# Custom configuration
MODEL_NAME="intfloat/multilingual-e5-large" \
PORT=31090 \
MAX_MODEL_LEN=10240 \
./openai_embedding_long_text_service.sh
```

### 2. Test Long Text Embedding

Run the comprehensive test client:

```bash
python openai_embedding_long_text_client.py
```

## 📁 Files

| File | Description |
|------|-------------|
| `openai_embedding_long_text_service.sh` | Server startup script with chunked processing enabled |
| `openai_embedding_long_text_client.py` | Comprehensive test client for long text embedding |
| `openai_embedding_client.py` | Basic embedding client (updated with chunked processing info) |

## ⚙️ Configuration

### Server Configuration

The key parameters for chunked processing are passed via `--override-pooler-config`:

```json
{
  "pooling_type": "CLS",
  "normalize": true,
  "enable_chunked_processing": true
}
```

### Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `MODEL_NAME` | `intfloat/multilingual-e5-large` | Embedding model to use |
| `PORT` | `31090` | Server port |
| `GPU_COUNT` | `1` | Number of GPUs to use |
| `MAX_MODEL_LEN` | `10240` | Maximum model context length |
| `API_KEY` | `EMPTY` | API key for authentication |

## 🔧 How It Works

1. **Automatic Detection**: When input text exceeds `max_model_len`, chunked processing is triggered
2. **Smart Chunking**: Text is split at token boundaries to maintain semantic integrity
3. **Independent Processing**: Each chunk is processed separately through the model (one inference call per chunk; see the sketch below)
4. **Weighted Aggregation**: Results are combined using token count-based weighted averaging
5. **Consistent Output**: Final embeddings maintain the same dimensionality as standard processing
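
As a rough sanity check for steps 1 and 3, the number of chunks (and therefore the number of inference calls) can be estimated from the tokenized input length. A minimal sketch, assuming the default `MAX_MODEL_LEN` above and mirroring the log example later in this README:

```python
import math

max_model_len = 10240  # server's --max-model-len
input_tokens = 15000   # tokenized length of the request

# Each chunk becomes a separate inference call, so latency grows with the chunk count.
num_chunks = max(1, math.ceil(input_tokens / max_model_len))
print(f"{input_tokens} tokens -> {num_chunks} chunks")  # 15000 tokens -> 2 chunks
```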

## 📊 Performance Characteristics

| Text Length | Processing Method | Memory Usage | Speed |
|-------------|-------------------|--------------|-------|
| ≤ max_len | Standard | Normal | Fast |
| > max_len | Chunked | Reduced per chunk | Slower (multiple inferences) |

## 🧪 Test Cases

The test client demonstrates:

- **Short text**: Normal processing (baseline)
- **Medium text**: Single chunk processing
- **Long text**: Multi-chunk processing with aggregation
- **Very long text**: Many chunks processing
- **Batch processing**: Mixed-length inputs in one request (see the sketch below)
- **Consistency**: Reproducible results across runs
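
Below is a minimal sketch of the batch test, assuming the Quick Start server is running on `localhost:31090` and, as in the documentation example, is addressed by the model name `multilingual-e5-large`; it only checks that chunked and standard outputs share the same dimensionality:

```python
from openai import OpenAI

client = OpenAI(api_key="your-api-key", base_url="http://localhost:31090/v1")

# Mixed-length batch: one short and one very long input in a single request
inputs = [
    "A short sentence.",
    "A very long passage that should trigger chunked processing. " * 2000,
]
response = client.embeddings.create(input=inputs, model="multilingual-e5-large")

# Chunked and standard processing should return embeddings of the same dimension
dims = [len(item.embedding) for item in response.data]
print(f"Embedding dimensions: {dims}")
assert len(set(dims)) == 1
```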

## 🐛 Troubleshooting

### Common Issues

1. **Chunked processing not enabled**:

    ```
    ValueError: This model's maximum context length is 512 tokens...
    ```

    **Solution**: Ensure `enable_chunked_processing: true` in the pooler config

2. **Memory errors**:

    ```
    RuntimeError: CUDA out of memory
    ```

    **Solution**: Reduce `MAX_MODEL_LEN` or use fewer GPUs

3. **Slow processing**:

    **Expected**: Long text takes more time due to multiple inference calls

### Debug Information

Server logs show chunked processing activity:

```
INFO: Input length 15000 exceeds max_model_len 10240, will use chunked processing
INFO: Split input of 15000 tokens into 2 chunks
```
## 📚 Additional Resources

- [Pooling Models Documentation](../../docs/models/pooling_models.md#chunked-processing-for-long-text)
- [Supported Models List](../../docs/models/supported_models.md#text-embedding)
- [Original Feature Documentation](../../README_CHUNKED_PROCESSING.md)

## 🤝 Contributing

To extend chunked processing support to other embedding models:

1. Check model compatibility with the pooling architecture
2. Test with various text lengths
3. Validate embedding quality compared to single-chunk processing (see the sketch below)
4. Submit a PR with test cases and documentation updates
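
For step 3, one simple (and admittedly rough) check is to compare the chunked embedding of a long text against the embedding of a truncated version that fits in a single chunk; a high cosine similarity suggests the aggregation preserves the overall representation. A sketch, assuming the Quick Start server is running; all names and the character-level truncation are illustrative:

```python
import numpy as np
from openai import OpenAI

client = OpenAI(api_key="your-api-key", base_url="http://localhost:31090/v1")


def embed(text: str) -> np.ndarray:
    response = client.embeddings.create(input=text, model="multilingual-e5-large")
    return np.array(response.data[0].embedding)


long_text = "Domain-specific content for quality validation. " * 2000
short_text = long_text[:2000]  # rough truncation that stays within one chunk

chunked_emb = embed(long_text)  # aggregated over multiple chunks
single_emb = embed(short_text)  # standard single-chunk processing

cosine = float(chunked_emb @ single_emb /
               (np.linalg.norm(chunked_emb) * np.linalg.norm(single_emb)))
print(f"Cosine similarity (chunked vs. truncated): {cosine:.4f}")
```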

---

**Note**: Chunked processing is currently supported for specific embedding models. See the [supported models documentation](../../docs/models/supported_models.md#text-embedding) for the complete list.
