
[Model] Add support for Jina Embeddings V4 #20802


Open · wants to merge 9 commits into main from jina-support

Conversation


@sigridjineth commented on Jul 11, 2025

Purpose

This PR adds support for the Jina Embeddings V4 model
(jinaai/jina-embeddings-v4-vllm-retrieval) in vLLM, enabling multimodal embeddings for text and image inputs.

FIX #20463

Test Plan

  • Run the comprehensive test suite for Jina V4:
    pytest tests/models/multimodal/test_jina_embeddings_v4.py -v

Documentation Update

  The following documentation has been added:

  1. Model Registry (vllm/model_executor/models/registry.py): Added a JinaVLForEmbedding entry
  2. Supported Models (docs/source/models/supported_models.rst): Added Jina Embeddings V4 to the multimodal embeddings section
  3. Example Script (examples/offline_inference/jina_embeddings_v4.py): Offline inference example for text and image embeddings (see the usage sketch below)
  4. Implementation Guide (JINA_EMBEDDINGS_V4_IMPLEMENTATION.md): Detailed technical documentation
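
For a feel of what the example script covers, here is a rough usage sketch. It assumes vLLM's standard embedding API (LLM(task="embed") and llm.embed(...)); the image prompt format is illustrative only, and the script added in this PR is the authoritative version.

    # Usage sketch (illustrative); see examples/offline_inference/jina_embeddings_v4.py
    # in this PR for the real example.
    from vllm import LLM
    from PIL import Image

    llm = LLM(
        model="jinaai/jina-embeddings-v4-vllm-retrieval",
        task="embed",  # run the model as a pooling/embedding model
    )

    # Text embedding
    text_output = llm.embed(["Query: What is vLLM?"])[0]
    print(len(text_output.outputs.embedding))

    # Image embedding via vLLM's multimodal prompt convention (prompt format assumed)
    image_output = llm.embed([{
        "prompt": "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
                  "Describe the image.<|im_end|>\n",
        "multi_modal_data": {"image": Image.open("example.jpg")},
    }])[0]
    print(len(image_output.outputs.embedding))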


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which executes a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @sigridjineth, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands vLLM's capabilities by integrating native support for the Jina Embeddings V4 multimodal model. My work ensures that users can now efficiently generate embeddings for both text and image inputs, benefiting from highly optimized processing, robust error handling, and comprehensive testing to guarantee accuracy and stability.

Highlights

  • Jina Embeddings V4 Integration: I've added full, production-ready support for the Jina Embeddings V4 model (jinaai/jina-embeddings-v4-vllm-retrieval), enabling both text and image multimodal embeddings within vLLM.
  • Optimized Multimodal Pooling: The implementation includes thread-safe, token-type-aware pooling, leveraging optimized Triton CUDA kernels for efficient vision token extraction, with a robust PyTorch fallback (a minimal sketch of the idea follows this list).
  • Robustness and Observability: I've incorporated comprehensive error handling, including graceful fallback mechanisms and OOM recovery during pooling, alongside full observability integration for performance metrics.
  • Comprehensive Testing & Examples: New unit tests, a validation benchmark against HuggingFace, and an offline inference example script have been added to ensure accuracy and demonstrate usage.
  • Documentation Updates: The model has been registered in the model registry, added to the supported models documentation, and a detailed implementation guide has been provided.
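
For readers skimming the highlights: the token-type-aware pooling described above boils down to mean-pooling the hidden states of the vision span (between the vision start and end markers) when one is present, mean-pooling all tokens otherwise, and then L2-normalizing. A minimal PyTorch sketch of that idea (not the PR's actual implementation; the token IDs are the ones checked elsewhere in this PR):

    import torch
    import torch.nn.functional as F

    VISION_START_TOKEN_ID = 151652
    VISION_END_TOKEN_ID = 151653

    def pool_sequence(hidden_states: torch.Tensor, token_ids: list[int]) -> torch.Tensor:
        """Mean-pool the vision span if present, otherwise all tokens, then L2-normalize."""
        if VISION_START_TOKEN_ID in token_ids and VISION_END_TOKEN_ID in token_ids:
            start = token_ids.index(VISION_START_TOKEN_ID)
            end = token_ids.index(VISION_END_TOKEN_ID)
            pooled = hidden_states[start:end + 1].mean(dim=0, dtype=torch.float32)
        else:
            pooled = hidden_states.mean(dim=0, dtype=torch.float32)
        return F.normalize(pooled, p=2, dim=-1)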

@mergify bot added the documentation (Improvements or additions to documentation), multi-modality (Related to multi-modality (#4194)), new-model (Requests to new models), and performance (Performance-related issues) labels on Jul 11, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This PR adds production-ready support for the Jina Embeddings V4 model. I've identified a bug in the tests, a performance issue in the core implementation, and some areas for code improvement in the example and validation scripts.

def test_seq_data_access(self):
"""Test that token IDs are properly accessible via seq_data."""
# Create mock sequence data
prompt_tokens = array('l', [101, 102, VISION_START_TOKEN_ID, 103, VISION_END_TOKEN_ID, 104])
Contributor

high

The prompt_tokens data in TestPoolingMetadataIntegration.test_seq_data_access seems incorrect. To properly test the pooling of multiple vision tokens, include a mock 'middle' token that also falls within the vision token ID range. Clarify the comment on line 282 to be more specific.

Suggested change
prompt_tokens = array('l', [101, 102, VISION_START_TOKEN_ID, 103, VISION_END_TOKEN_ID, 104])
prompt_tokens = array('l', [101, 102, VISION_START_TOKEN_ID, VISION_START_TOKEN_ID, VISION_END_TOKEN_ID, 104])

output = seq_states.mean(dim=0)

# Normalize
output = F.normalize(output, p=2, dim=-1)
Contributor

high

In _apply_vision_pooling_optimized, if has_vision is true but the Triton kernel doesn't find any vision tokens, the output tensor will be all zeros. Normalizing a zero vector with F.normalize will result in NaNs. Add a check to ensure the Triton path is as robust as the PyTorch path.


Comment on lines 341 to 345
finally:
# Clean up any temporary tensors
if 'token_tensor' in locals():
del token_tensor
torch.cuda.empty_cache()
Contributor

high

The use of torch.cuda.empty_cache() in the finally block of the pooler method can introduce significant performance overhead. Consider removing this finally block to avoid performance degradation.

Suggested change
finally:
# Clean up any temporary tensors
if 'token_tensor' in locals():
del token_tensor
torch.cuda.empty_cache()
finally:
# Rely on Python's garbage collector for releasing tensors.
# torch.cuda.empty_cache() is a blocking and expensive operation
# that should be used sparingly.
pass


Comment on lines 96 to 97
(input_ids >= 151652) &
(input_ids <= 151653)
Contributor

medium

The magic numbers 151652 and 151653 are used for vision token ID checks. Define them as VISION_START_TOKEN_ID and VISION_END_TOKEN_ID constants, respectively, to improve readability and maintain consistency with other files in this PR.
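
For illustration, the suggested change would look roughly like this (the constant names match those used in the tests added by this PR; the surrounding has_vision expression is assumed from the snippet above):

    # Vision start/end marker token IDs used throughout this PR.
    VISION_START_TOKEN_ID = 151652
    VISION_END_TOKEN_ID = 151653

    has_vision = (
        (input_ids >= VISION_START_TOKEN_ID) &
        (input_ids <= VISION_END_TOKEN_ID)
    ).any()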


Comment on lines 29 to 46
img_start_pos = torch.where(
torch.tensor(output.prompt_token_ids) == VISION_START_TOKEN_ID
)[0][0]
img_end_pos = torch.where(
torch.tensor(output.prompt_token_ids) == VISION_END_TOKEN_ID
)[0][0]
embeddings_tensor = output.outputs.data.detach().clone()[
img_start_pos : img_end_pos + 1
]
else:
# For text-only inputs, use all token embeddings
embeddings_tensor = output.outputs.data.detach().clone()

# Pool and normalize embeddings
pooled_output = (
embeddings_tensor.sum(dim=0, dtype=torch.float32)
/ embeddings_tensor.shape[0]
)
Contributor

medium

Consider using list.index() for finding token positions and torch.mean() for computing tensor means for better efficiency and code clarity.

            img_start_pos = output.prompt_token_ids.index(VISION_START_TOKEN_ID)
            img_end_pos = output.prompt_token_ids.index(VISION_END_TOKEN_ID)
            embeddings_tensor = output.outputs.data.detach().clone()[
                img_start_pos : img_end_pos + 1
            ]
        else:
            # For text-only inputs, use all token embeddings
            embeddings_tensor = output.outputs.data.detach().clone()
        
        # Pool and normalize embeddings
        pooled_output = embeddings_tensor.mean(dim=0, dtype=torch.float32)


@DarkLight1337 (Member) left a comment

Thanks for contributing! Can you add this model to the test registry and supported models page?



# Triton kernel for optimized vision token extraction
if HAS_TRITON:
@DarkLight1337 (Member) commented on Jul 11, 2025

How much of a performance increase does using Triton give to justify this additional complexity? cc @Isotr0py @imkero

Author

Will provide Triton performance benchmarks after finishing up some tasks in the PR.

Collaborator

If this Triton kernel is only used in the pooler, I think the performance improvement will be very small. But it would be best to have performance benchmarks first.

Member

Can you perform benchmarking on this?
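
A minimal timing harness along these lines could answer the benchmarking question. pool_triton and pool_pytorch are hypothetical stand-ins for _apply_vision_pooling_optimized and _apply_vision_pooling_pytorch, and the tensor shapes are arbitrary:

    import torch

    def benchmark(fn, *args, warmup: int = 10, iters: int = 100) -> float:
        """Return the average milliseconds per call, measured with CUDA events."""
        for _ in range(warmup):
            fn(*args)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn(*args)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters

    hidden = torch.randn(8192, 2048, device="cuda", dtype=torch.float16)
    token_ids = torch.randint(0, 152_000, (8192,), device="cuda")
    # print("triton :", benchmark(pool_triton, hidden, token_ids))
    # print("pytorch:", benchmark(pool_pytorch, hidden, token_ids))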

Comment on lines 356 to 363
try:
loader = AutoWeightsLoader(self)
loaded_weights = loader.load_weights(weights, mapper=self.weight_mapper)
logger.info(f"Successfully loaded {len(loaded_weights)} weight tensors")
return loaded_weights
except Exception as e:
logger.error(f"Error loading weights: {e}")
raise
@Isotr0py (Collaborator) commented on Jul 11, 2025

Suggested change
try:
loader = AutoWeightsLoader(self)
loaded_weights = loader.load_weights(weights, mapper=self.weight_mapper)
logger.info(f"Successfully loaded {len(loaded_weights)} weight tensors")
return loaded_weights
except Exception as e:
logger.error(f"Error loading weights: {e}")
raise
loader = AutoWeightsLoader(self)
loaded_weights = loader.load_weights(weights, mapper=self.weight_mapper)
return loaded_weights

Please clean up the try/except statement used for debugging.

Author

amended 1b594aa




hidden_states, token_ids_list, prompt_lens
)
except RuntimeError as e:
if "out of memory" in str(e).lower():
Member

I suggest controlling this via config. Otherwise, we have to keep handling the OOM exception in corner cases where OOM occurs only in self._apply_vision_pooling_optimized but not in self._apply_vision_pooling_pytorch, which greatly harms performance.
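
One way to structure that, as a sketch only (use_optimized_vision_pooling is a hypothetical flag, not an existing vLLM option; the _apply_vision_pooling_* names are the methods discussed above):

    from dataclasses import dataclass

    @dataclass
    class PoolerConfig:
        use_optimized_vision_pooling: bool = True  # hypothetical config flag

    class JinaVLPooler:
        def __init__(self, config: PoolerConfig):
            self.config = config

        def pool_with_vision(self, hidden_states, token_ids_list, prompt_lens):
            # Choose the path up front from config instead of catching OOM at runtime,
            # so a recurring OOM in the optimized path is not retried on every batch.
            if self.config.use_optimized_vision_pooling:
                return self._apply_vision_pooling_optimized(
                    hidden_states, token_ids_list, prompt_lens)
            return self._apply_vision_pooling_pytorch(
                hidden_states, token_ids_list, prompt_lens)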

Author

amended 48c4ea4

# Validate lengths match
if len(token_ids_list) != len(prompt_lens):
logger.error(f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths")
return self._base_pooler(hidden_states, pooling_metadata)
Member

This should raise an assertion error IMO
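
That is, something along these lines (a sketch of the suggested change):

    # Fail fast on a length mismatch rather than silently falling back to the base pooler.
    assert len(token_ids_list) == len(prompt_lens), (
        f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths")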

Author

amended 714e283


# Normalize and handle potential NaNs by replacing with zeros
output = F.normalize(output, p=2, dim=-1)
output = torch.nan_to_num(output)
Member

Actually this is probably unnecessary, as F.normalize has an eps parameter to avoid division by zero.
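
A quick check confirms this: F.normalize clamps the denominator to eps (default 1e-12), so a zero vector stays zero rather than turning into NaNs.

    import torch
    import torch.nn.functional as F

    zero = torch.zeros(4)
    print(F.normalize(zero, p=2, dim=-1))  # tensor([0., 0., 0., 0.]) -- no NaNs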

@sigridjineth (Author) commented on Jul 11, 2025

yup thanks d0d7b26

Comment on lines 12 to 19
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
triton = None
tl = None
Collaborator

Suggested change
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
triton = None
tl = None
from vllm.triton_utils import tl, triton

Use this to automatically fall back to the Triton placeholder when Triton is unavailable.

Author

done caea1fe

Comment on lines 48 to 59
@triton.jit
def extract_vision_tokens_kernel(
hidden_states_ptr,
token_ids_ptr,
output_ptr,
seq_start,
seq_len,
hidden_size,
vision_start_id: tl.constexpr,
vision_end_id: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
Collaborator

I don't like putting a Triton kernel in the model implementation; we should move this to pooler.py or somewhere else if the performance improvement is significant.

Author

done caea1fe

@sigridjineth force-pushed the jina-support branch 2 times, most recently from 54839c7 to 4eb5e88 on July 11, 2025 at 14:38
@sigridjineth (Author)

@Isotr0py @DarkLight1337 please take another look if you think more changes are needed.

@DarkLight1337 (Member)

Sorry for the delay, can you merge from main and fix pre-commit?

@DarkLight1337 DarkLight1337 added this to the v0.10.0 milestone Jul 16, 2025
Labels
documentation (Improvements or additions to documentation), multi-modality (Related to multi-modality (#4194)), new-model (Requests to new models), performance (Performance-related issues)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[New Model]: jinaai/jina-embeddings-v4
3 participants