diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index f0de84a66f8..e4e1436c545 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -32,6 +32,137 @@ we attempt to override the default pooler based on its Sentence Transformers con You can customize the model's pooling method via the `--override-pooler-config` option, which takes priority over both the model's and Sentence Transformers's defaults. +## Chunked Processing for Long Text + +vLLM supports **chunked processing** for embedding models to handle text inputs that exceed the model's maximum token length. This feature automatically splits long text into manageable chunks, processes them separately, and aggregates the results. + +### Supported Models + +Chunked processing is supported for the following embedding models: + +- `intfloat/multilingual-e5-large` (Recommended pool type: `MEAN`) +- `jinaai/jina-embeddings-v3` (Recommended pool type: `MEAN`) +- `jinaai/jina-embeddings-v4-vllm-retrieval` (Recommended pool type: `MEAN`) +- `Qwen/Qwen3-Embedding-4B` (Recommended pool type: `MEAN`) + +Other embedding models can be extended to support this feature by ensuring proper pooling type compatibility. + +### How Chunked Processing Works + +1. **Flexible Input Validation**: Configure `max_embed_len` to accept inputs longer than `max_model_len` without environment variables +2. **Smart Chunking**: Text is split based on `max_position_embeddings` to maintain semantic integrity +3. **Parallel Processing**: Each chunk is processed independently through the model +4. **Intelligent Aggregation**: Results are combined using weighted averaging based on chunk token counts +5. **Consistent Output**: Final embeddings maintain the same dimensionality as standard processing + +### Configuration + +Enable chunked processing and configure maximum embedding input length: + +```bash +vllm serve intfloat/multilingual-e5-large \ + --task embed \ + --override-pooler-config '{"pooling_type": "MEAN", "normalize": true, "enable_chunked_processing": true, "max_embed_len": 3072000}' \ + --trust-remote-code +``` + +#### Configuration Parameters + +- `enable_chunked_processing`: Enable chunked processing for long inputs (default: `false`) +- `max_embed_len`: Maximum input length allowed for embedding generation (default: `null`) + - When set, allows inputs longer than `max_model_len` without requiring `VLLM_ALLOW_LONG_MAX_MODEL_LEN` + - Inputs exceeding `max_embed_len` are rejected with clear error messages + - Chunking is triggered when inputs exceed `max_position_embeddings` + +### Aggregation Algorithm + +The chunked processing uses a FastChat-inspired weighted averaging algorithm: + +```python +# Weighted average: sum(embedding_i * token_count_i) / total_tokens +weighted_sum = sum(embeddings[i] * weights[i] for i in range(num_chunks)) +final_embedding = weighted_sum / sum(weights) +``` + +This ensures that longer chunks contribute proportionally more to the final representation. 
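+
+Below is a small, self-contained sketch of this weighted averaging, assuming each chunk's embedding and token count have already been collected; the helper name `aggregate_chunk_embeddings` is illustrative and not part of vLLM's API. The optional normalization mirrors the `"normalize": true` setting in the pooler config.
+
+```python
+import torch
+
+
+def aggregate_chunk_embeddings(
+    chunk_embeddings: list[torch.Tensor],  # one embedding vector per chunk
+    chunk_token_counts: list[int],  # number of tokens in each chunk
+    normalize: bool = True,
+) -> torch.Tensor:
+    """Token-count weighted average of per-chunk embeddings."""
+    weights = torch.tensor(chunk_token_counts, dtype=torch.float32)
+    stacked = torch.stack([e.to(torch.float32) for e in chunk_embeddings])
+    # sum(embedding_i * token_count_i) / total_tokens
+    weighted_sum = (stacked * weights.unsqueeze(1)).sum(dim=0)
+    final_embedding = weighted_sum / weights.sum()
+    if normalize:
+        final_embedding = final_embedding / final_embedding.norm()
+    return final_embedding
+```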
+ +### Performance Characteristics + +| Aspect | Short Text (≤ max_position_embeddings) | Long Text (> max_position_embeddings) | +|--------|----------------------------------------|---------------------------------------| +| **Processing Time** | Standard | Increased (multiple inference calls) | +| **Memory Usage** | Standard | Reduced (chunks processed separately) | +| **Quality** | Standard | Maintains semantic representation | +| **Compatibility** | Full | Full (backward compatible) | +| **Input Validation** | Standard max_model_len check | Extended max_embed_len check | + +#### Extreme Long Text Support + +With the enhanced `max_embed_len` configuration (up to 3M+ tokens), you can process: +- **Complete Documents**: Research papers, legal contracts, technical manuals +- **Large Codebases**: Entire repositories and documentation +- **Books and Literature**: Full chapters or small books +- **Multi-document Analysis**: Combined content for comprehensive understanding + +### Example Usage + +#### Basic Configuration + +```python +from openai import OpenAI + +client = OpenAI( + api_key="your-api-key", + base_url="http://localhost:31090/v1" +) + +# This will automatically use chunked processing for very long text +# max_embed_len=3072000 allows inputs up to 3M+ tokens +response = client.embeddings.create( + input="Very long text that exceeds the model's position embeddings..." * 5000, + model="multilingual-e5-large" +) + +print(f"Embedding dimension: {len(response.data[0].embedding)}") +``` + +#### Alternative Model Configurations + +```bash +# For Jina embeddings v3 (optimized for performance) +vllm serve jinaai/jina-embeddings-v3 \ + --task embed \ + --override-pooler-config '{"pooling_type": "MEAN", "normalize": true, "enable_chunked_processing": true, "max_embed_len": 1048576}' \ + --trust-remote-code + +# For Jina embeddings v4 (latest retrieval model) +vllm serve jinaai/jina-embeddings-v4-vllm-retrieval \ + --task embed \ + --override-pooler-config '{"pooling_type": "MEAN", "normalize": true, "enable_chunked_processing": true, "max_embed_len": 2097152}' \ + --trust-remote-code + +# For Qwen3 Embedding (large-scale multilingual) +vllm serve Qwen/Qwen3-Embedding-4B \ + --task embed \ + --override-pooler-config '{"pooling_type": "MEAN", "normalize": true, "enable_chunked_processing": true, "max_embed_len": 1572864}' \ + --trust-remote-code +``` + +### Logging and Monitoring + +When chunked processing is active, you'll see informative log messages: + +``` +INFO: Input length 100000 exceeds max_position_embeddings 512, will use chunked processing +INFO: Split input of 100000 tokens into 196 chunks (max_chunk_size: 512) +``` + +### Limitations + +- **Increased Latency**: Processing multiple chunks takes longer than single-chunk processing +- **Model Support**: Currently limited to specific embedding models +- **Context Boundaries**: Chunking may split related content, though weighted averaging helps preserve overall semantics + ## Offline Inference The [LLM][vllm.LLM] class provides various methods for offline inference. @@ -170,7 +301,7 @@ vllm serve jinaai/jina-embeddings-v3 --trust-remote-code You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter. 
```text -curl http://127.0.0.1:8000/v1/embeddings \ +curl http://127.0.0.1:31090/v1/embeddings \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index ddc920aeb2d..a9597e45fd5 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -418,7 +418,7 @@ Specified using `--task embed`. | `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | | | `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | | | `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | | -| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, `intfloat/multilingual-e5-large` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | | @@ -437,6 +437,9 @@ Specified using `--task embed`. !!! note The second-generation GTE model (mGTE-TRM) is named `NewModel`. The name `NewModel` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewModel"]}'` to specify the use of the `GteNewModel` architecture. +!!! note + `intfloat/multilingual-e5-large` supports **long text embedding** with chunked processing. When input text exceeds the model's maximum length, the model automatically splits the input into chunks and processes them separately, then aggregates the results. Enable this feature with `--override-pooler-config '{"pooling_type": "CLS", "normalize": true, "enable_chunked_processing": true}'`. See the [Chunked Processing section](pooling_models.md#chunked-processing-for-long-text) for more details. + If your model is not in the above list, we will try to automatically convert the model using [as_embedding_model][vllm.model_executor.models.adapters.as_embedding_model]. By default, the embeddings of the whole prompt are extracted from the normalized hidden state corresponding to the last token. diff --git a/examples/online_serving/openai_embedding_long_text.md b/examples/online_serving/openai_embedding_long_text.md new file mode 100644 index 00000000000..c1c044d916b --- /dev/null +++ b/examples/online_serving/openai_embedding_long_text.md @@ -0,0 +1,179 @@ +# Long Text Embedding with Chunked Processing + +This directory contains examples for using vLLM's **chunked processing** feature to handle long text embedding that exceeds the model's maximum context length. + +## 🚀 Quick Start + +### 1. 
Start the Server + +Use the provided script to start a vLLM server with chunked processing enabled: + +```bash +# Basic usage (supports very long texts up to ~3M tokens) +./openai_embedding_long_text_service.sh + +# Custom configuration with different models +MODEL_NAME="jinaai/jina-embeddings-v3" \ +MAX_EMBED_LEN=1048576 \ +./openai_embedding_long_text_service.sh + +# For extremely long documents +MODEL_NAME="intfloat/multilingual-e5-large" \ +MAX_EMBED_LEN=3072000 \ +./openai_embedding_long_text_service.sh +``` + +### 2. Test Long Text Embedding + +Run the comprehensive test client: + +```bash +python openai_embedding_long_text_client.py +``` + +## 📁 Files + +| File | Description | +|------|-------------| +| `openai_embedding_long_text_service.sh` | Server startup script with chunked processing enabled | +| `openai_embedding_long_text_client.py` | Comprehensive test client for long text embedding | +| `openai_embedding_client.py` | Basic embedding client (updated with chunked processing info) | + +## ⚙️ Configuration + +### Server Configuration + +The key parameters for chunked processing are in the `--override-pooler-config`: + +```json +{ + "pooling_type": "MEAN", + "normalize": true, + "enable_chunked_processing": true, + "max_embed_len": 3072000 +} +``` + +### Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `MODEL_NAME` | `intfloat/multilingual-e5-large` | Embedding model to use (supports multiple models) | +| `PORT` | `31090` | Server port | +| `GPU_COUNT` | `1` | Number of GPUs to use | +| `MAX_EMBED_LEN` | `3072000` | Maximum embedding input length (supports very long documents) | +| `API_KEY` | `EMPTY` | API key for authentication | + +## 🔧 How It Works + +1. **Enhanced Input Validation**: `max_embed_len` allows accepting inputs longer than `max_model_len` without environment variables +2. **Smart Chunking**: Text is split based on `max_position_embeddings` to maintain semantic integrity +3. **Independent Processing**: Each chunk is processed separately through the model +4. **Weighted Aggregation**: Results are combined using token count-based weighted averaging +5. 
**Consistent Output**: Final embeddings maintain the same dimensionality as standard processing + +### Input Length Handling + +- **Within max_embed_len**: Input is accepted and processed (up to 3M+ tokens) +- **Exceeds max_position_embeddings**: Chunked processing is automatically triggered +- **Exceeds max_embed_len**: Input is rejected with clear error message +- **No environment variables required**: Works without `VLLM_ALLOW_LONG_MAX_MODEL_LEN` + +### Extreme Long Text Support + +With `MAX_EMBED_LEN=3072000`, you can process: +- **Academic papers**: Full research papers with references +- **Legal documents**: Complete contracts and legal texts +- **Books**: Entire chapters or small books +- **Code repositories**: Large codebases and documentation + +## 📊 Performance Characteristics + +| Text Length | Processing Method | Memory Usage | Speed | +|-------------|------------------|--------------|-------| +| ≤ max_position_embeddings | Standard | Normal | Fast | +| > max_position_embeddings, ≤ max_embed_len | Chunked | Reduced per chunk | Slower (multiple inferences) | +| > max_embed_len | Rejected | N/A | Error response | + +## 🧪 Test Cases + +The test client demonstrates: + +- ✅ **Short text**: Normal processing (baseline) +- ✅ **Medium text**: Single chunk processing +- ✅ **Long text**: Multi-chunk processing with aggregation +- ✅ **Very long text**: Many chunks processing +- ✅ **Extreme long text**: Document-level processing (100K+ tokens) +- ✅ **Batch processing**: Mixed-length inputs in one request +- ✅ **Consistency**: Reproducible results across runs + +## 🐛 Troubleshooting + +### Common Issues + +1. **Chunked processing not enabled**: + + ``` + ValueError: This model's maximum position embeddings length is 4096 tokens... + ``` + + **Solution**: Ensure `enable_chunked_processing: true` in pooler config + +2. **Input exceeds max_embed_len**: + + ``` + ValueError: This model's maximum embedding input length is 3072000 tokens... + ``` + + **Solution**: Increase `max_embed_len` in pooler config or reduce input length + +3. **Memory errors**: + + ``` + RuntimeError: CUDA out of memory + ``` + + **Solution**: Reduce chunk size by adjusting model's `max_position_embeddings` or use fewer GPUs + +4. **Slow processing**: + **Expected**: Long text takes more time due to multiple inference calls + +### Debug Information + +Server logs show chunked processing activity: + +``` +INFO: Input length 150000 exceeds max_position_embeddings 4096, will use chunked processing +INFO: Split input of 150000 tokens into 37 chunks (max_chunk_size: 4096) +``` + +## 📚 Additional Resources + +- [Pooling Models Documentation](../../docs/models/pooling_models.md#chunked-processing-for-long-text) +- [Supported Models List](../../docs/models/supported_models.md#text-embedding) +- [Original Feature Documentation](../../README_CHUNKED_PROCESSING.md) + +## 🤝 Contributing + +To extend chunked processing support to other embedding models: + +1. Check model compatibility with the pooling architecture +2. Test with various text lengths +3. Validate embedding quality compared to single-chunk processing +4. 
Submit PR with test cases and documentation updates + +## 🆕 Enhanced Features + +### max_embed_len Parameter + +The new `max_embed_len` parameter provides: + +- **Simplified Configuration**: No need for `VLLM_ALLOW_LONG_MAX_MODEL_LEN` environment variable +- **Flexible Input Validation**: Accept inputs longer than `max_model_len` up to `max_embed_len` +- **Extreme Length Support**: Process documents with millions of tokens +- **Clear Error Messages**: Better feedback when inputs exceed limits +- **Backward Compatibility**: Existing configurations continue to work + +--- + +**Note**: Chunked processing is currently supported for specific embedding models. See the [supported models documentation](../../docs/models/supported_models.md#chunked-processing-for-long-text) for the complete list. diff --git a/examples/online_serving/openai_embedding_long_text_client.py b/examples/online_serving/openai_embedding_long_text_client.py new file mode 100644 index 00000000000..fb645ed975e --- /dev/null +++ b/examples/online_serving/openai_embedding_long_text_client.py @@ -0,0 +1,348 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Example script demonstrating long text embedding with chunked processing in vLLM. + +This example shows how to use vLLM's chunked processing feature to handle text +inputs that exceed the model's maximum token length. The feature automatically +splits long text into chunks and aggregates the results. + +Prerequisites: +1. Start vLLM server with chunked processing enabled: + + vllm serve intfloat/multilingual-e5-large \ + --task embed \ + --override-pooler-config \ + '{"pooling_type": "CLS", "normalize": true, ' \ + '"enable_chunked_processing": true, "max_embed_len": 10240}' \ + --served-model-name multilingual-e5-large \ + --trust-remote-code \ + --port 31090 \ + --api-key your-api-key + +2. Install required dependencies: + pip install openai requests +""" + +import time + +import numpy as np +from openai import OpenAI + +# Configuration +API_KEY = "your-api-key" # Replace with your actual API key +BASE_URL = "http://localhost:31090/v1" +MODEL_NAME = "multilingual-e5-large" + + +def generate_long_text(base_text: str, repeat_count: int) -> str: + """Generate long text by repeating base text.""" + return base_text * repeat_count + + +def test_embedding_with_different_lengths(): + """Test embedding generation with different text lengths.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + # Test cases with different text lengths + test_cases = [ + { + "name": "Short Text", + "text": "Hello, this is a short text for embedding.", + "expected_chunks": 1, + }, + { + "name": "Medium Text", + "text": generate_long_text( + "This is a medium-length text that should fit within the " + "model's context window. " * 20, + 2, + ), + "expected_chunks": 1, + }, + { + "name": "Long Text (2 chunks)", + "text": generate_long_text( + "This is a very long text that will exceed the model's " + "maximum context length and trigger chunked processing. " * 50, + 5, + ), + "expected_chunks": 2, + }, + { + "name": "Very Long Text (3+ chunks)", + "text": generate_long_text( + "This text is extremely long and will definitely " + "require multiple chunks for processing. 
" * 100, + 10, + ), + "expected_chunks": 3, + }, + ] + + print("🧪 Testing vLLM Long Text Embedding with Chunked Processing") + print("=" * 70) + + for i, test_case in enumerate(test_cases, 1): + print(f"\n📝 Test {i}: {test_case['name']}") + print(f"Text length: {len(test_case['text'])} characters") + + try: + start_time = time.time() + + response = client.embeddings.create( + input=test_case["text"], model=MODEL_NAME, encoding_format="float" + ) + + end_time = time.time() + processing_time = end_time - start_time + + # Extract embedding data + embedding = response.data[0].embedding + embedding_dim = len(embedding) + + print("✅ Success!") + print(f" - Embedding dimension: {embedding_dim}") + print(f" - Processing time: {processing_time:.2f}s") + print(f" - Expected chunks: ~{test_case['expected_chunks']}") + print(f" - First 5 values: {embedding[:5]}") + + except Exception as e: + print(f"❌ Failed: {str(e)}") + + +def test_batch_embedding(): + """Test batch embedding with mixed-length inputs.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + print("\n🔄 Testing Batch Embedding with Mixed Lengths") + print("=" * 50) + + # Mix of short and long texts + batch_inputs = [ + "Short text 1", + generate_long_text("Medium length text that fits in one chunk. " * 20, 1), + "Another short text", + generate_long_text("Long text requiring chunked processing. " * 100, 5), + ] + + try: + start_time = time.time() + + response = client.embeddings.create( + input=batch_inputs, model=MODEL_NAME, encoding_format="float" + ) + + end_time = time.time() + processing_time = end_time - start_time + + print("✅ Batch processing successful!") + print(f" - Number of inputs: {len(batch_inputs)}") + print(f" - Number of embeddings: {len(response.data)}") + print(f" - Total processing time: {processing_time:.2f}s") + print( + f" - Average time per input: {processing_time / len(batch_inputs):.2f}s" + ) + + for i, data in enumerate(response.data): + input_length = len(batch_inputs[i]) + embedding_dim = len(data.embedding) + print( + f" - Input {i + 1}: {input_length} chars → {embedding_dim}D embedding" + ) + + except Exception as e: + print(f"❌ Batch processing failed: {str(e)}") + + +def test_multiple_long_texts_batch(): + """Test batch processing with multiple long texts to verify chunk ID uniqueness.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + print("\n🔧 Testing Multiple Long Texts in Batch (Chunk ID Fix Verification)") + print("=" * 70) + + # Create multiple distinct long texts that will all require chunking + long_texts = [ + generate_long_text( + "First long document about artificial intelligence and machine learning. " + * 80, + 6, + ), + generate_long_text( + "Second long document about natural language processing and transformers. " + * 80, + 6, + ), + generate_long_text( + "Third long document about computer vision and neural networks. 
" * 80, 6 + ), + ] + + # Add some short texts to mix things up + batch_inputs = [ + "Short text before long texts", + long_texts[0], + "Short text between long texts", + long_texts[1], + long_texts[2], + "Short text after long texts", + ] + + print("📊 Batch composition:") + for i, text in enumerate(batch_inputs): + length = len(text) + text_type = "Long (will be chunked)" if length > 5000 else "Short" + print(f" - Input {i + 1}: {length} chars ({text_type})") + + try: + start_time = time.time() + + response = client.embeddings.create( + input=batch_inputs, model=MODEL_NAME, encoding_format="float" + ) + + end_time = time.time() + processing_time = end_time - start_time + + print("\n✅ Multiple long texts batch processing successful!") + print(f" - Number of inputs: {len(batch_inputs)}") + print(f" - Number of embeddings returned: {len(response.data)}") + print(f" - Total processing time: {processing_time:.2f}s") + + # Verify each embedding is different (no incorrect aggregation) + embeddings = [data.embedding for data in response.data] + + if len(embeddings) >= 3: + import numpy as np + + # Compare embeddings of the long texts (indices 1, 3, 4) + long_embeddings = [ + np.array(embeddings[1]), # First long text + np.array(embeddings[3]), # Second long text + np.array(embeddings[4]), # Third long text + ] + + print("\n🔍 Verifying embedding uniqueness:") + for i in range(len(long_embeddings)): + for j in range(i + 1, len(long_embeddings)): + cosine_sim = np.dot(long_embeddings[i], long_embeddings[j]) / ( + np.linalg.norm(long_embeddings[i]) + * np.linalg.norm(long_embeddings[j]) + ) + print( + f" - Similarity between long text {i + 1} and {j + 1}: " + f"{cosine_sim:.4f}" + ) + + if ( + cosine_sim < 0.9 + ): # Different content should have lower similarity + print(" ✅ Good: Embeddings are appropriately different") + else: + print( + " ⚠️ High similarity - may indicate chunk " + "aggregation issue" + ) + + print("\n📋 Per-input results:") + for i, data in enumerate(response.data): + input_length = len(batch_inputs[i]) + embedding_dim = len(data.embedding) + embedding_norm = np.linalg.norm(data.embedding) + print( + f" - Input {i + 1}: {input_length} chars → {embedding_dim}D " + f"embedding (norm: {embedding_norm:.4f})" + ) + + print( + "\n✅ This test verifies the fix for chunk ID collisions in " + "batch processing" + ) + print(" - Before fix: Multiple long texts would have conflicting chunk IDs") + print(" - After fix: Each prompt's chunks have unique IDs with prompt index") + + except Exception as e: + print(f"❌ Multiple long texts batch test failed: {str(e)}") + print(" This might indicate the chunk ID collision bug is present!") + + +def test_embedding_consistency(): + """Test that chunked processing produces consistent results.""" + client = OpenAI(api_key=API_KEY, base_url=BASE_URL) + + print("\n🔍 Testing Embedding Consistency") + print("=" * 40) + + # Use the same long text multiple times + long_text = generate_long_text( + "Consistency test text for chunked processing validation. 
" * 50, 3 + ) + + embeddings = [] + + try: + for i in range(3): + response = client.embeddings.create( + input=long_text, model=MODEL_NAME, encoding_format="float" + ) + embeddings.append(response.data[0].embedding) + print(f" - Generated embedding {i + 1}") + + # Check consistency (embeddings should be identical) + if len(embeddings) >= 2: + # Calculate similarity between first two embeddings + + emb1 = np.array(embeddings[0]) + emb2 = np.array(embeddings[1]) + + # Cosine similarity + cosine_sim = np.dot(emb1, emb2) / ( + np.linalg.norm(emb1) * np.linalg.norm(emb2) + ) + + print("✅ Consistency test completed!") + print(f" - Cosine similarity between runs: {cosine_sim:.6f}") + print(" - Expected: ~1.0 (identical embeddings)") + + if cosine_sim > 0.999: + print(" - ✅ High consistency achieved!") + else: + print(" - ⚠️ Consistency may vary due to numerical precision") + + except Exception as e: + print(f"❌ Consistency test failed: {str(e)}") + + +def main(): + """Main function to run all tests.""" + print("🚀 vLLM Long Text Embedding Client") + print(f"📡 Connecting to: {BASE_URL}") + print(f"🤖 Model: {MODEL_NAME}") + masked_key = "*" * (len(API_KEY) - 4) + API_KEY[-4:] if len(API_KEY) > 4 else "****" + print(f"🔑 API Key: {masked_key}") + + # Run all test cases + test_embedding_with_different_lengths() + test_batch_embedding() + test_multiple_long_texts_batch() + test_embedding_consistency() + + print("\n" + "=" * 70) + print("🎉 All tests completed!") + print("\n💡 Key Features Demonstrated:") + print(" - ✅ Automatic chunked processing for long text") + print(" - ✅ Seamless handling of mixed-length batches") + print(" - ✅ Multiple long texts in single batch (chunk ID fix)") + print(" - ✅ Consistent embedding generation") + print(" - ✅ Backward compatibility with short text") + print("\n📚 For more information, see:") + print( + " - Documentation: https://docs.vllm.ai/en/latest/models/pooling_models.html" + ) + print(" - Chunked Processing Guide: openai_embedding_long_text.md") + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_embedding_long_text_service.sh b/examples/online_serving/openai_embedding_long_text_service.sh new file mode 100644 index 00000000000..fa78385e782 --- /dev/null +++ b/examples/online_serving/openai_embedding_long_text_service.sh @@ -0,0 +1,145 @@ +#!/bin/bash + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# vLLM Embedding Server with Enhanced Chunked Processing +# This script starts a vLLM server with chunked processing enabled for long text embedding. +# Now supports proper pooling type validation and model-specific configurations. 
+ +set -euo pipefail + +# Configuration +MODEL_NAME=${MODEL_NAME:-"intfloat/multilingual-e5-large"} +MODEL_CODE=${MODEL_CODE:-"multilingual-e5-large"} + +PORT=${PORT:-31090} +GPU_COUNT=${GPU_COUNT:-1} +MAX_EMBED_LEN=${MAX_EMBED_LEN:-3072000} +API_KEY=${API_KEY:-"your-api-key"} + +# Enhanced pooling configuration with model-specific defaults +POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST +ALLOW_NON_MEAN_CHUNKING=${ALLOW_NON_MEAN_CHUNKING:-"false"} +# export CUDA_VISIBLE_DEVICES=2,3,4,5 + +echo "🚀 Starting vLLM Embedding Server with Enhanced Chunked Processing" +echo "==================================================================" + +# Environment variables for optimization +export VLLM_WORKER_MULTIPROC_METHOD=spawn + +# Function to determine optimal pooling type for known models +get_optimal_pooling_type() { + local model="$1" + case "$model" in + *"e5-"* | *"multilingual-e5"*) + echo "MEAN" # E5 series uses mean pooling + ;; + *"bge-"*) + echo "CLS" # BGE series uses CLS pooling + ;; + *"gte-"*) + echo "MEAN" # GTE series uses mean pooling + ;; + *"sentence-t5"* | *"st5"*) + echo "MEAN" # Sentence-T5 uses mean pooling + ;; + *) + echo "MEAN" # Default to MEAN for unknown models + ;; + esac +} + +# Auto-detect pooling type if not explicitly set +if [ "$POOLING_TYPE" = "auto" ]; then + POOLING_TYPE=$(get_optimal_pooling_type "$MODEL_NAME") + echo "🔍 Auto-detected pooling type: $POOLING_TYPE for model $MODEL_NAME" +fi + +# Display configuration +echo "📋 Configuration:" +echo " - Model: $MODEL_NAME" +echo " - Port: $PORT" +echo " - GPU Count: $GPU_COUNT" +echo " - Enhanced Chunked Processing: ENABLED" +echo " - Max Embed Length: ${MAX_EMBED_LEN} tokens" +echo " - Pooling Type: $POOLING_TYPE + Normalization" +echo " - Allow Non-MEAN Chunking: $ALLOW_NON_MEAN_CHUNKING" +echo "" + +# Validate GPU availability +if command -v nvidia-smi &> /dev/null; then + gpu_count=$(nvidia-smi --list-gpus | wc -l) + echo "🖥️ Available GPUs: $gpu_count" + if [ "$GPU_COUNT" -gt "$gpu_count" ]; then + echo "⚠️ Warning: Requested $GPU_COUNT GPUs but only $gpu_count available" + echo " Adjusting to use $gpu_count GPUs" + GPU_COUNT=$gpu_count + fi +else + echo "⚠️ Warning: nvidia-smi not found. GPU detection skipped." +fi + +# Warning for non-MEAN pooling types +if [ "$POOLING_TYPE" != "MEAN" ] && [ "$ALLOW_NON_MEAN_CHUNKING" != "true" ]; then + echo "" + echo "⚠️ IMPORTANT: Using $POOLING_TYPE pooling with chunked processing" + echo " This may produce different results than non-chunked processing." + echo " For BERT-type models with bidirectional attention, consider:" + echo " - Using MEAN pooling for mathematically equivalent results" + echo " - Setting ALLOW_NON_MEAN_CHUNKING=true to suppress this warning" + echo "" +fi + +echo "" +echo "🔧 Starting server with enhanced chunked processing configuration..." 
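+
+# With the defaults above (POOLING_TYPE auto-detected as MEAN for the E5 model),
+# the pooler config built below expands to:
+#   {"pooling_type": "MEAN", "normalize": true, "enable_chunked_processing": true, "max_embed_len": 3072000}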
+ +# Build pooler config JSON +POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enable_chunked_processing\": true, \"max_embed_len\": ${MAX_EMBED_LEN}" + +# Add allow_non_mean_chunking if needed +if [ "$ALLOW_NON_MEAN_CHUNKING" = "true" ]; then + POOLER_CONFIG="${POOLER_CONFIG}, \"allow_non_mean_chunking\": true" +fi + +POOLER_CONFIG="${POOLER_CONFIG}}" + +# Start vLLM server with enhanced chunked processing +vllm serve "$MODEL_NAME" \ + --tensor-parallel-size "$GPU_COUNT" \ + --enforce-eager \ + --override-pooler-config "$POOLER_CONFIG" \ + --served-model-name ${MODEL_CODE} \ + --task embed \ + --use-v2-block-manager \ + --api-key "$API_KEY" \ + --trust-remote-code \ + --port "$PORT" \ + --host 0.0.0.0 + +echo "" +echo "✅ vLLM Embedding Server started successfully!" +echo "" +echo "📡 Server Information:" +echo " - Base URL: http://localhost:$PORT" +echo " - Model Code: ${MODEL_CODE}" +echo " - API Key: $API_KEY" +echo " - Pooling Strategy: $POOLING_TYPE" +echo "" +echo "🧪 Test the server with:" +echo " python examples/online_serving/openai_embedding_long_text_client.py" +echo "" +echo "📚 Enhanced features enabled:" +echo " ✅ Intelligent pooling type detection and validation" +echo " ✅ Long text chunked processing with proper aggregation" +echo " ✅ Model-specific pooling strategy optimization" +echo " ✅ Enhanced max embedding length (${MAX_EMBED_LEN} tokens)" +echo " ✅ Automatic chunk aggregation (MEAN/CLS/LAST support)" +echo " ✅ OpenAI-compatible API" +echo " ✅ GPU acceleration" +echo "" +echo "🔧 Advanced usage:" +echo " - Set POOLING_TYPE=MEAN|CLS|LAST to override auto-detection" +echo " - Set ALLOW_NON_MEAN_CHUNKING=true for non-MEAN pooling without warnings" +echo " - Set MAX_EMBED_LEN to adjust maximum input length" diff --git a/vllm/config.py b/vllm/config.py index b1f7f9e57a7..344fe0142d2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3240,6 +3240,35 @@ class PoolerConfig: ``math-shepherd-mistral-7b-prm`` model. """ + enable_chunked_processing: Optional[bool] = None + """ + Whether to enable chunked processing for long inputs that exceed the model's + maximum position embeddings. When enabled, long inputs will be split into + chunks, processed separately, and then aggregated using weighted averaging. + This allows embedding models to handle arbitrarily long text without CUDA + errors. Defaults to False. + """ + + max_embed_len: Optional[int] = None + """ + Maximum input length allowed for embedding generation. When set, allows + inputs longer than max_model_len to be accepted for embedding models. + This parameter enables accepting long inputs without requiring + VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds + max_embed_len, it will be handled according to the original max_model_len + validation logic. Defaults to None (use max_model_len validation). + """ + + allow_non_mean_chunking: Optional[bool] = None + """ + Whether to allow chunked processing for non-MEAN pooling types without + warnings. By default (None or False), a warning will be shown when using + chunked processing with pooling types other than MEAN, as they may produce + different results than non-chunked processing. Set to True to explicitly + allow and suppress warnings for non-MEAN pooling types. Only applies when + enable_chunked_processing is True. 
+ """ + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index e87decfe636..26eae3b2b8f 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -2,9 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 -from typing import Final, Literal, Optional, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, Final, Literal, Optional, Union, cast import numpy as np +import torch from fastapi import Request from typing_extensions import assert_never, override @@ -12,18 +14,25 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this docstring +# yapf: disable from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, + EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, OpenAIServing, - ServeContext) + ServeContext, + TextTokensPrompt) +# yapf: enable from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, - PoolingRequestOutput) + PoolingRequestOutput, RequestOutput) logger = init_logger(__name__) @@ -133,6 +142,700 @@ def _build_response( usage=usage, ) + def _get_max_position_embeddings(self) -> int: + """Get the model's effective maximum sequence length for chunking. + + This uses the same logic as vLLM's _get_and_verify_max_len to determine + the actual sequence length limit, + considering both model config and tokenizer config. 
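+
+        Falls back to 512 if the HF config does not define
+        max_position_embeddings, and is capped by the tokenizer's
+        model_max_length when a tokenizer config is available.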
+ """ + hf_config = self.model_config.hf_config + + # Start with max_position_embeddings from model config + derived_max_len = getattr(hf_config, 'max_position_embeddings', 512) + + # Get tokenizer config for pooling models (embedding models) + if self.model_config.runner_type == "pooling": + from vllm.transformers_utils.config import try_get_tokenizer_config + tokenizer_config = try_get_tokenizer_config( + self.model_config.tokenizer, + trust_remote_code=self.model_config.trust_remote_code, + revision=self.model_config.tokenizer_revision) + + # Consider model_max_length in tokenizer_config + # (same logic as _get_and_verify_max_len) + if tokenizer_config: + tokenizer_model_max_length = tokenizer_config.get( + 'model_max_length', derived_max_len) + derived_max_len = min(derived_max_len, + tokenizer_model_max_length) + + return int(derived_max_len) + + def _should_use_chunked_processing(self, request) -> bool: + """Check if chunked processing should be used for this request.""" + if not isinstance(request, + (EmbeddingChatRequest, EmbeddingCompletionRequest)): + return False + + pooler_config = getattr(self.model_config, 'pooler_config', None) + if not (pooler_config is not None and getattr( + pooler_config, 'enable_chunked_processing', False)): + return False + + # Check pooling type compatibility for chunked processing + pooling_type = getattr(pooler_config, 'pooling_type', None) + if pooling_type: + pooling_type_upper = pooling_type.upper() + + # For LAST and CLS pooling, chunked processing doesn't make + # semantic sense because only the last/first chunk + # contains the relevant token position + if pooling_type_upper in ['LAST', 'CLS']: + # Check if user explicitly allowed non-mean chunking + allow_non_mean = getattr(pooler_config, + 'allow_non_mean_chunking', False) + if not allow_non_mean: + logger.warning( + "Chunked processing with pooling type '%s' " + "is not recommended as it may produce semantically " + "incorrect results. %s pooling relies on specific " + "token positions that lose their meaning when the " + "sequence is chunked. Consider using MEAN pooling " + "or disable chunked processing. Set " + "'allow_non_mean_chunking: true' ", + "to override this warning.", pooling_type, + pooling_type_upper) + return False # Disable chunked processing by default + else: + logger.info( + "Using chunked processing with %s pooling " + "(explicitly enabled). Note: only the %s chunk " + "will be processed to avoid computational waste.", + pooling_type_upper, + "last" if pooling_type_upper == "LAST" else "first") + + # Warn about non-MEAN pooling types (for other pooling types) + elif pooling_type_upper != 'MEAN': + # Check if user explicitly allowed non-mean chunking + allow_non_mean = getattr(pooler_config, + 'allow_non_mean_chunking', False) + if not allow_non_mean: + logger.warning( + "Chunked processing with pooling type '%s' " + "may produce different results than non-chunked " + "processing. Only MEAN pooling is mathematically " + "equivalent when using weighted averaging aggregation. " + "For other pooling types, different aggregation " + "strategies will be used that approximate the original " + "behavior. 
Set 'allow_non_mean_chunking: true' " + "in pooler config to suppress this warning.", + pooling_type) + # Still allow it but with warning + else: + logger.info( + "Using chunked processing with pooling type " + "'%s' (explicitly enabled)", pooling_type) + + return True + + def _chunk_token_ids(self, token_ids: list[int], + chunk_size: int) -> list[list[int]]: + """Split token IDs into chunks of specified size.""" + if len(token_ids) <= chunk_size: + return [token_ids] + + chunks = [] + for i in range(0, len(token_ids), chunk_size): + chunk = token_ids[i:i + chunk_size] + chunks.append(chunk) + return chunks + + async def _process_chunked_request( + self, + ctx: EmbeddingServeContext, + original_prompt: TextTokensPrompt, + pooling_params, + trace_headers, + prompt_idx: int, + ) -> list[AsyncGenerator[PoolingRequestOutput, None]]: + """Process a single prompt using chunked processing.""" + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + token_ids = original_prompt["prompt_token_ids"] + + # Split into chunks using max_position_embeddings + max_pos_embeddings = self._get_max_position_embeddings() + chunks = self._chunk_token_ids(token_ids, max_pos_embeddings) + + # Check pooling type to optimize chunk processing + pooler_config = getattr(self.model_config, 'pooler_config', None) + pooling_type = getattr(pooler_config, 'pooling_type', 'MEAN') + if pooling_type: + pooling_type = pooling_type.upper() + + # For LAST pooling, only process the last chunk + # For CLS pooling, only process the first chunk + if pooling_type == 'LAST': + chunks_to_process = [chunks[-1]] + chunk_indices = [len(chunks) - 1] + logger.info("LAST pooling: processing only the last chunk") + elif pooling_type == 'CLS': + chunks_to_process = [chunks[0]] + chunk_indices = [0] + logger.info("CLS pooling: processing only the first chunk") + else: + # For MEAN and other pooling types, process all chunks + chunks_to_process = chunks + chunk_indices = list(range(len(chunks))) + logger.info("Using chunked processing for MEAN pooling") + + for i, (chunk_idx, chunk_tokens) in enumerate( + zip(chunk_indices, chunks_to_process)): + # Create a request ID for this chunk + chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-" + f"chunk-{chunk_idx}") + + # Create engine prompt for this chunk + chunk_engine_prompt = EngineTokensPrompt( + prompt_token_ids=chunk_tokens) + + # Create chunk request prompt for logging + chunk_text = "" + chunk_request_prompt = TextTokensPrompt( + prompt=chunk_text, prompt_token_ids=chunk_tokens) + + # Log the chunk + self._log_inputs(chunk_request_id, + chunk_request_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + prompt_adapter_request=ctx.prompt_adapter_request) + + # Create generator for this chunk + generator = self.engine_client.encode( + chunk_engine_prompt, + pooling_params, + chunk_request_id, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(generator) + + return generators + + def _validate_input( + self, + request, + input_ids: list[int], + input_text: str, + ) -> TextTokensPrompt: + """Override to support chunked processing for embedding requests.""" + token_num = len(input_ids) + + # Note: EmbeddingRequest doesn't have max_tokens + if isinstance(request, + (EmbeddingChatRequest, EmbeddingCompletionRequest)): + # Check if chunked processing is enabled for pooling models + pooler_config = getattr(self.model_config, 'pooler_config', None) + enable_chunked = (pooler_config 
is not None and getattr( + pooler_config, 'enable_chunked_processing', False)) + + # Get max_embed_len from pooler config if set + max_embed_len = (pooler_config.max_embed_len if pooler_config + and pooler_config.max_embed_len else None) + + # Use max_position_embeddings for chunked processing decisions + max_pos_embeddings = self._get_max_position_embeddings() + + # Determine the effective max length for validation + if max_embed_len is not None: + # Use max_embed_len for validation instead of max_model_len + effective_max_len = max_embed_len + validation_error_msg = ( + f"This model's maximum embedding input length is " + f"{max_embed_len} tokens. However, you requested " + f"{token_num} tokens in the input for embedding " + f"generation. Please reduce the length of the input.") + else: + # Fall back to max_model_len validation (original behavior) + effective_max_len = self.max_model_len + validation_error_msg = ( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{token_num} tokens in the input for embedding " + f"generation. Please reduce the length of the input.") + + # Check if input exceeds effective max length + if token_num > effective_max_len: + raise ValueError(validation_error_msg) + + # Check for chunked processing + # when exceeding max_position_embeddings + if token_num > max_pos_embeddings: + if enable_chunked: + # Allow long inputs when chunked processing is enabled + logger.info( + "Input length %s exceeds max_position_embeddings " + "%s, will use chunked processing", token_num, + max_pos_embeddings) + else: + raise ValueError( + f"This model's maximum position embeddings length is " + f"{max_pos_embeddings} tokens. However, you requested " + f"{token_num} tokens in the input for embedding " + f"generation. 
Please reduce the length of the input or " + f"enable chunked processing.") + + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) + + # For other request types, use the parent's implementation + return super()._validate_input(request, input_ids, input_text) + + def _is_text_tokens_prompt(self, prompt) -> bool: + """Check if a prompt is a TextTokensPrompt (has prompt_token_ids).""" + return (isinstance(prompt, dict) and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt) + + async def _prepare_generators( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Override to support chunked processing.""" + ctx = cast(EmbeddingServeContext, ctx) + generators: list[AsyncGenerator[Union[RequestOutput, + PoolingRequestOutput], + None]] = [] + + try: + trace_headers = (None if ctx.raw_request is None else await + self._get_trace_headers(ctx.raw_request.headers)) + + if not hasattr(ctx.request, "to_pooling_params"): + return self.create_error_response( + "Request type does not support pooling parameters") + + pooling_params = ctx.request.to_pooling_params() + + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + if ctx.request_prompts is None: + return self.create_error_response( + "Request prompts not available") + + # Check if we should use chunked processing + use_chunked = self._should_use_chunked_processing(ctx.request) + + for i, engine_prompt in enumerate(ctx.engine_prompts): + request_prompt = ctx.request_prompts[i] + + # Check if this specific prompt needs chunked processing + max_pos_embeddings = self._get_max_position_embeddings() + if (use_chunked + and self._is_text_tokens_prompt(request_prompt)): + # Cast to TextTokensPrompt since we've + # verified prompt_token_ids + text_tokens_prompt = cast(TextTokensPrompt, request_prompt) + if len(text_tokens_prompt["prompt_token_ids"] + ) > max_pos_embeddings: + # Use chunked processing for this prompt + chunk_generators = await self._process_chunked_request( + ctx, text_tokens_prompt, pooling_params, + trace_headers, i) + generators.extend(chunk_generators) + continue + + # Normal processing for short prompts or non-token prompts + request_id_item = f"{ctx.request_id}-{i}" + + self._log_inputs( + request_id_item, + request_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + prompt_adapter_request=ctx.prompt_adapter_request) + + # Mypy has an existing bug related to inferring the variance + # of TypedDicts with `builtins.enumerate`: + # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 + engine_prompt = cast( + Union[EngineTokensPrompt, EngineEmbedsPrompt], + engine_prompt) + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(generator) + + from vllm.utils import merge_async_iterators + ctx.result_generator = merge_async_iterators(*generators) + + return None + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + async def _collect_batch( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Collect and aggregate batch results + with support for chunked processing. + + For chunked requests, performs online aggregation to + minimize memory usage. + For regular requests, collects results normally. 
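+
+        Aggregation strategy for chunked prompts, by pooling type:
+        - MEAN: token-count weighted average over all chunk outputs
+        - LAST: the last chunk's output is used directly
+        - CLS: the first chunk's output is used directly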
+ """ + ctx = cast(EmbeddingServeContext, ctx) + try: + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + if ctx.request_prompts is None: + return self.create_error_response( + "Request prompts not available") + + if ctx.result_generator is None: + return self.create_error_response( + "Result generator not available") + + # Check if we used chunked processing + use_chunked = self._should_use_chunked_processing(ctx.request) + + if use_chunked: + # Online aggregation for chunked requests to + # minimize memory usage + # Track aggregation state for each prompt + prompt_aggregators: dict[int, dict[str, Any]] = {} + short_prompts_results: dict[int, PoolingRequestOutput] = {} + + async for result_idx, result in ctx.result_generator: + if "-chunk-" in result.request_id: + # Extract prompt_idx from chunked request_id + parts = result.request_id.split("-") + try: + prompt_idx = int(parts[parts.index("prompt") + 1]) + + # Initialize aggregator for this prompt if needed + if prompt_idx not in prompt_aggregators: + # Get pooling type to determine + # aggregation strategy + pooler_config = getattr( + self.model_config, 'pooler_config', None) + pooling_type = getattr(pooler_config, + 'pooling_type', 'MEAN') + if pooling_type: + pooling_type = pooling_type.upper() + + prompt_aggregators[prompt_idx] = { + 'pooling_type': + pooling_type, + 'weighted_sum': + None, + 'total_weight': + 0, + 'first_result': + None, + 'last_result': + None, + 'chunk_count': + 0, + 'request_id': + result.request_id.split("-chunk-")[0] + } + + aggregator = prompt_aggregators[prompt_idx] + pooling_type = aggregator['pooling_type'] + + # Handle different pooling types with + # online aggregation + if pooling_type == 'MEAN': + # Online weighted averaging + # Ensure result is PoolingRequestOutput + # for embedding processing + if not isinstance(result, + PoolingRequestOutput): + return self.create_error_response( + f"Expected PoolingRequestOutput for " + f"chunked embedding, got " + f"{type(result).__name__}") + + embedding_data = result.outputs.data + if not isinstance(embedding_data, + torch.Tensor): + embedding_data = torch.tensor( + embedding_data, dtype=torch.float32) + + if result.prompt_token_ids is None: + return self.create_error_response( + "prompt_token_ids cannot be None for " + "chunked processing") + weight = len(result.prompt_token_ids) + + weighted_embedding = embedding_data.to( + dtype=torch.float32) * weight + + if aggregator['weighted_sum'] is None: + # First chunk + aggregator[ + 'weighted_sum'] = weighted_embedding + else: + # Accumulate + current_sum = aggregator['weighted_sum'] + if isinstance(current_sum, torch.Tensor): + aggregator['weighted_sum'] = ( + current_sum + weighted_embedding) + + total_weight = aggregator['total_weight'] + if isinstance(total_weight, (int, float)): + aggregator['total_weight'] = ( + total_weight + weight) + + elif pooling_type == 'LAST': + # Keep only the + # last result (highest chunk index) + if not isinstance(result, + PoolingRequestOutput): + return self.create_error_response( + f"Expected PoolingRequestOutput for " + f"chunked embedding, got " + f"{type(result).__name__}") + + chunk_idx = int(parts[parts.index("chunk") + + 1]) + last_chunk_idx = aggregator.get( + 'last_chunk_idx', -1) + # Ensure last_chunk_idx is an integer + # for comparison + if not isinstance(last_chunk_idx, int): + last_chunk_idx = -1 + if (aggregator['last_result'] is None + or chunk_idx > last_chunk_idx): + aggregator['last_result'] = result + 
aggregator['last_chunk_idx'] = chunk_idx + + elif pooling_type == 'CLS': + # Keep only the first result (chunk index 0) + if not isinstance(result, + PoolingRequestOutput): + return self.create_error_response( + f"Expected PoolingRequestOutput for " + f"chunked embedding, got " + f"{type(result).__name__}") + + chunk_idx = int(parts[parts.index("chunk") + + 1]) + if chunk_idx == 0: + aggregator['first_result'] = result + + chunk_count = aggregator['chunk_count'] + if isinstance(chunk_count, int): + aggregator['chunk_count'] = chunk_count + 1 + + except (ValueError, IndexError): + return self.create_error_response( + f"Invalid chunk request ID format: " + f"{result.request_id}") + else: + # Non-chunked result + try: + prompt_idx = int(result.request_id.split("-")[-1]) + short_prompts_results[prompt_idx] = cast( + PoolingRequestOutput, result) + except ValueError: + return self.create_error_response( + f"Invalid request ID format: " + f"{result.request_id}") + + # Build final result batch + final_res_batch = [] + + for prompt_idx, request_prompt in enumerate( + ctx.request_prompts): + if prompt_idx in prompt_aggregators: + # Finalize aggregation for this chunked prompt + aggregator = prompt_aggregators[prompt_idx] + pooling_type = aggregator['pooling_type'] + + if pooling_type == 'MEAN': + # Finalize weighted average + weighted_sum = aggregator['weighted_sum'] + total_weight = aggregator['total_weight'] + if (weighted_sum is not None + and isinstance(weighted_sum, torch.Tensor) + and isinstance(total_weight, (int, float)) + and total_weight > 0): + final_embedding = weighted_sum / total_weight + + # Create aggregated result + from vllm.outputs import PoolingOutput + aggregated_output = PoolingOutput( + data=final_embedding) + + # Get original prompt token ids + if self._is_text_tokens_prompt(request_prompt): + text_tokens_prompt = cast( + TextTokensPrompt, request_prompt) + original_token_ids = text_tokens_prompt[ + "prompt_token_ids"] + else: + return self.create_error_response( + f"Chunked prompt {prompt_idx} is not a " + f"text tokens prompt") + + # Ensure request_id is string + request_id = aggregator['request_id'] + if not isinstance(request_id, str): + return self.create_error_response( + f"Invalid request_id type: " + f"{type(request_id)}") + + aggregated_result = PoolingRequestOutput( + request_id=request_id, + outputs=aggregated_output, + prompt_token_ids=original_token_ids, + finished=True, + ) + final_res_batch.append(aggregated_result) + else: + return self.create_error_response( + f"No valid aggregation data for prompt " + f"{prompt_idx}") + + elif pooling_type == 'LAST': + if aggregator['last_result'] is not None: + # Use the last chunk result + last_result = aggregator['last_result'] + if not isinstance(last_result, + PoolingRequestOutput): + return self.create_error_response( + f"Expected PoolingRequestOutput for " + f"last_result, got " + f"{type(last_result).__name__}") + + if self._is_text_tokens_prompt(request_prompt): + text_tokens_prompt = cast( + TextTokensPrompt, request_prompt) + original_token_ids = text_tokens_prompt[ + "prompt_token_ids"] + + # Ensure request_id is string + request_id = aggregator['request_id'] + if not isinstance(request_id, str): + return self.create_error_response( + f"Invalid request_id type: " + f"{type(request_id)}") + + aggregated_result = PoolingRequestOutput( + request_id=request_id, + outputs=last_result.outputs, + prompt_token_ids=original_token_ids, + finished=True, + ) + final_res_batch.append(aggregated_result) + else: + return 
self.create_error_response( + f"Chunked prompt {prompt_idx} is not a " + f"text tokens prompt") + else: + return self.create_error_response( + f"No LAST result found for prompt " + f"{prompt_idx}") + + elif pooling_type == 'CLS': + if aggregator['first_result'] is not None: + # Use the first chunk result + first_result = aggregator['first_result'] + if not isinstance(first_result, + PoolingRequestOutput): + return self.create_error_response( + f"Expected PoolingRequestOutput for " + f"first_result, got " + f"{type(first_result).__name__}") + + if self._is_text_tokens_prompt(request_prompt): + text_tokens_prompt = cast( + TextTokensPrompt, request_prompt) + original_token_ids = text_tokens_prompt[ + "prompt_token_ids"] + + # Ensure request_id is string + request_id = aggregator['request_id'] + if not isinstance(request_id, str): + return self.create_error_response( + f"Invalid request_id type: " + f"{type(request_id)}") + + aggregated_result = PoolingRequestOutput( + request_id=request_id, + outputs=first_result.outputs, + prompt_token_ids=original_token_ids, + finished=True, + ) + final_res_batch.append(aggregated_result) + else: + return self.create_error_response( + f"Chunked prompt {prompt_idx} is not a " + f"text tokens prompt") + else: + return self.create_error_response( + f"No CLS result found for prompt " + f"{prompt_idx}") + else: + return self.create_error_response( + f"Unsupported pooling type for chunked " + f"processing: {pooling_type}") + + elif prompt_idx in short_prompts_results: + # This was a short prompt + final_res_batch.append( + short_prompts_results[prompt_idx]) + else: + return self.create_error_response( + f"Result not found for prompt {prompt_idx}") + + ctx.final_res_batch = cast( + list[Union[RequestOutput, PoolingRequestOutput]], + final_res_batch) + else: + # Normal processing for non-chunked requests + num_prompts = len(ctx.engine_prompts) + normal_final_res_batch: list[ + Optional[PoolingRequestOutput]] = [None] * num_prompts + + async for result_idx, result in ctx.result_generator: + if result_idx < num_prompts: + # Cast to PoolingRequestOutput for embedding results + normal_final_res_batch[result_idx] = cast( + PoolingRequestOutput, result) + + if None in normal_final_res_batch: + return self.create_error_response( + "Failed to generate results for all prompts") + + final_results = [ + res for res in normal_final_res_batch if res is not None + ] + ctx.final_res_batch = cast( + list[Union[RequestOutput, PoolingRequestOutput]], + final_results) + + return None + + except Exception as e: + return self.create_error_response(str(e)) + class OpenAIServingEmbedding(EmbeddingMixin): request_id_prefix = "embd"