[Model] Add support for Jina Embeddings V4 #20802
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge. 🚀
Summary of Changes
Hello @sigridjineth, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly expands vLLM's capabilities by integrating native support for the Jina Embeddings V4 multimodal model. My work ensures that users can now efficiently generate embeddings for both text and image inputs, benefiting from highly optimized processing, robust error handling, and comprehensive testing to guarantee accuracy and stability.
Highlights
- Jina Embeddings V4 Integration: I've added full, production-ready support for the Jina Embeddings V4 model (jinaai/jina-embeddings-v4-vllm-retrieval), enabling both text and image multimodal embeddings within vLLM.
- Optimized Multimodal Pooling: The implementation includes thread-safe, token-type-aware pooling, leveraging optimized Triton CUDA kernels for efficient vision token extraction, with a robust PyTorch fallback (see the sketch after this list).
- Robustness and Observability: I've incorporated comprehensive error handling, including graceful fallback mechanisms and OOM recovery during pooling, alongside full observability integration for performance metrics.
- Comprehensive Testing & Examples: New unit tests, a validation benchmark against HuggingFace, and an offline inference example script have been added to ensure accuracy and demonstrate usage.
- Documentation Updates: The model has been registered in the model registry, added to the supported models documentation, and a detailed implementation guide has been provided.
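For orientation, here is a minimal sketch of what the token-type-aware pooling amounts to in its PyTorch-fallback form. The vision-marker token IDs match the ones quoted later in this review; the function itself is illustrative, not the exact code in this PR:

```python
import torch
import torch.nn.functional as F

# Vision-marker token IDs as referenced elsewhere in this PR.
VISION_START_TOKEN_ID = 151652
VISION_END_TOKEN_ID = 151653

def pool_sequence(hidden_states: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    """Mean-pool vision-token embeddings if any are present, otherwise all tokens,
    then L2-normalize the result."""
    vision_mask = (token_ids >= VISION_START_TOKEN_ID) & (token_ids <= VISION_END_TOKEN_ID)
    if vision_mask.any():
        pooled = hidden_states[vision_mask].mean(dim=0)
    else:
        pooled = hidden_states.mean(dim=0)
    return F.normalize(pooled, p=2, dim=-1)
```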
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.
Code Review
This PR adds production-ready support for the Jina Embeddings V4 model. I've identified a bug in the tests, a performance issue in the core implementation, and some areas for code improvement in the example and validation scripts.
def test_seq_data_access(self):
    """Test that token IDs are properly accessible via seq_data."""
    # Create mock sequence data
    prompt_tokens = array('l', [101, 102, VISION_START_TOKEN_ID, 103, VISION_END_TOKEN_ID, 104])
The prompt_tokens data in TestPoolingMetadataIntegration.test_seq_data_access seems incorrect. To properly test the pooling of multiple vision tokens, include a mock 'middle' token that also falls within the vision token ID range, and clarify the comment on line 282 to be more specific.
- prompt_tokens = array('l', [101, 102, VISION_START_TOKEN_ID, 103, VISION_END_TOKEN_ID, 104])
+ prompt_tokens = array('l', [101, 102, VISION_START_TOKEN_ID, VISION_START_TOKEN_ID, VISION_END_TOKEN_ID, 104])
output = seq_states.mean(dim=0)

# Normalize
output = F.normalize(output, p=2, dim=-1)
mistake. fixed here. d7d6b60#diff-bc18208ee0cfb6279e0c9ef1df4692b32b16d1d9a8a1eb650867372b9ac4612a
finally:
    # Clean up any temporary tensors
    if 'token_tensor' in locals():
        del token_tensor
    torch.cuda.empty_cache()
The use of torch.cuda.empty_cache() in the finally block of the pooler method can introduce significant performance overhead. Consider removing this finally block to avoid performance degradation.
- finally:
-     # Clean up any temporary tensors
-     if 'token_tensor' in locals():
-         del token_tensor
-     torch.cuda.empty_cache()
+ finally:
+     # Rely on Python's garbage collector for releasing tensors.
+     # torch.cuda.empty_cache() is a blocking and expensive operation
+     # that should be used sparingly.
+     pass
(input_ids >= 151652) &
(input_ids <= 151653)
    img_start_pos = torch.where(
        torch.tensor(output.prompt_token_ids) == VISION_START_TOKEN_ID
    )[0][0]
    img_end_pos = torch.where(
        torch.tensor(output.prompt_token_ids) == VISION_END_TOKEN_ID
    )[0][0]
    embeddings_tensor = output.outputs.data.detach().clone()[
        img_start_pos : img_end_pos + 1
    ]
else:
    # For text-only inputs, use all token embeddings
    embeddings_tensor = output.outputs.data.detach().clone()

# Pool and normalize embeddings
pooled_output = (
    embeddings_tensor.sum(dim=0, dtype=torch.float32)
    / embeddings_tensor.shape[0]
)
Consider using list.index() for finding token positions and torch.mean() for computing tensor means, for better efficiency and code clarity.
    img_start_pos = output.prompt_token_ids.index(VISION_START_TOKEN_ID)
    img_end_pos = output.prompt_token_ids.index(VISION_END_TOKEN_ID)
    embeddings_tensor = output.outputs.data.detach().clone()[
        img_start_pos : img_end_pos + 1
    ]
else:
    # For text-only inputs, use all token embeddings
    embeddings_tensor = output.outputs.data.detach().clone()

# Pool and normalize embeddings
pooled_output = embeddings_tensor.mean(dim=0, dtype=torch.float32)
Force-pushed from 34f3e7f to d7d6b60.
Thanks for contributing! Can you add this model to the test registry and supported models page?
# Triton kernel for optimized vision token extraction
if HAS_TRITON:
Would provide Triton performance benchmarks after finishing up some tasks in the PR.
If this Triton kernel is only used in the pooler, I think the performance improvement will be very little. But it would be best to have performance benchmarks first.
Can you perform benchmarking on this?
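One way to get those numbers would be a small CUDA-event microbenchmark like the sketch below. The time_cuda helper is generic; the commented usage is hypothetical and assumes a constructed pooler exposing the _apply_vision_pooling_optimized / _apply_vision_pooling_pytorch methods from this PR, plus synthetic hidden_states, token_ids_list, and prompt_lens inputs.

```python
import torch

def time_cuda(fn, iters: int = 50) -> float:
    """Average milliseconds per call, measured with CUDA events."""
    for _ in range(5):  # warm-up (Triton JIT compilation, allocator, etc.)
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Hypothetical usage against this PR's two pooling paths:
# t_triton = time_cuda(lambda: pooler._apply_vision_pooling_optimized(
#     hidden_states, token_ids_list, prompt_lens))
# t_torch = time_cuda(lambda: pooler._apply_vision_pooling_pytorch(
#     hidden_states, token_ids_list, prompt_lens))
# print(f"triton: {t_triton:.3f} ms | pytorch: {t_torch:.3f} ms")
```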
try:
    loader = AutoWeightsLoader(self)
    loaded_weights = loader.load_weights(weights, mapper=self.weight_mapper)
    logger.info(f"Successfully loaded {len(loaded_weights)} weight tensors")
    return loaded_weights
except Exception as e:
    logger.error(f"Error loading weights: {e}")
    raise
- try:
-     loader = AutoWeightsLoader(self)
-     loaded_weights = loader.load_weights(weights, mapper=self.weight_mapper)
-     logger.info(f"Successfully loaded {len(loaded_weights)} weight tensors")
-     return loaded_weights
- except Exception as e:
-     logger.error(f"Error loading weights: {e}")
-     raise
+ loader = AutoWeightsLoader(self)
+ loaded_weights = loader.load_weights(weights, mapper=self.weight_mapper)
+ return loaded_weights
Please clean up the try: ... except: ... statement used for debugging.
amended 1b594aa
        hidden_states, token_ids_list, prompt_lens
    )
except RuntimeError as e:
    if "out of memory" in str(e).lower():
I suggest controlling this via config. Otherwise, we have to keep handling the OOM exception in corner cases where OOM only occurs in self._apply_vision_pooling_optimized but not in self._apply_vision_pooling_pytorch, which greatly harms performance.
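A rough sketch of that idea, with a hypothetical config flag (only the two method names come from this PR):

```python
# Hypothetical: select the pooling path once from config instead of
# catching CUDA OOM from the Triton path and retrying in PyTorch.
use_triton = HAS_TRITON and getattr(
    pooler_config, "use_triton_vision_pooling", True  # flag name is made up
)
if use_triton:
    pooled = self._apply_vision_pooling_optimized(
        hidden_states, token_ids_list, prompt_lens)
else:
    pooled = self._apply_vision_pooling_pytorch(
        hidden_states, token_ids_list, prompt_lens)
```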
amended 48c4ea4
# Validate lengths match
if len(token_ids_list) != len(prompt_lens):
    logger.error(f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths")
    return self._base_pooler(hidden_states, pooling_metadata)
This should raise an assertion error IMO
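For example (a sketch of the suggested change):

```python
assert len(token_ids_list) == len(prompt_lens), (
    f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths"
)
```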
amended 714e283
# Normalize and handle potential NaNs by replacing with zeros
output = F.normalize(output, p=2, dim=-1)
output = torch.nan_to_num(output)
Actually this is probably unnecessary, as F.normalize has an eps parameter to avoid division by zero.
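i.e. the nan_to_num call could likely be dropped, since F.normalize already guards the denominator (eps defaults to 1e-12):

```python
output = F.normalize(output, p=2, dim=-1, eps=1e-12)
```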
yup thanks d0d7b26
try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    triton = None
    tl = None
- try:
-     import triton
-     import triton.language as tl
-     HAS_TRITON = True
- except ImportError:
-     HAS_TRITON = False
-     triton = None
-     tl = None
+ from vllm.triton_utils import tl, triton
Use this so the Triton placeholder is picked up automatically when Triton is unavailable.
done caea1fe
@triton.jit
def extract_vision_tokens_kernel(
    hidden_states_ptr,
    token_ids_ptr,
    output_ptr,
    seq_start,
    seq_len,
    hidden_size,
    vision_start_id: tl.constexpr,
    vision_end_id: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
I don't like putting a Triton kernel in the model implementation; we should move this to pooler.py or somewhere else if the performance improvement is significant.
done caea1fe
Force-pushed from 54839c7 to 4eb5e88.
Force-pushed from 4eb5e88 to caea1fe.
@Isotr0py @DarkLight1337 please review and let me know if you think more changes are needed.
Sorry for the delay, can you merge from main and fix pre-commit?
Purpose
This PR adds support for the Jina Embeddings V4 model (jinaai/jina-embeddings-v4-vllm-retrieval) in vLLM, enabling multimodal embeddings for text and image inputs.

FIX #20463
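For reference, offline usage would look roughly like the sketch below. This is an illustration, not the exact example script added in this PR: it assumes the model runs under vLLM's embed task, and the "Query:"/"Passage:" prefixes follow Jina's retrieval prompt convention (an assumption here).

```python
from vllm import LLM

# Assumption: the model is registered for the embedding task by this PR.
llm = LLM(model="jinaai/jina-embeddings-v4-vllm-retrieval", task="embed")

outputs = llm.embed([
    "Query: What is vLLM?",
    "Passage: vLLM is a high-throughput and memory-efficient inference engine for LLMs.",
])

for out in outputs:
    embedding = out.outputs.embedding  # list[float]
    print(len(embedding))
```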
Test Plan