[Model] Add support for Jina Embeddings V4 #20802

Open

wants to merge 24 commits into base: main

Changes from 1 commit

Commits (24)
9fbc0e9
feat: jina support
sigridjineth Jul 11, 2025
eea8462
refactor: fail fast
sigridjineth Jul 11, 2025
5e247e9
refactor: exceptions
sigridjineth Jul 11, 2025
9be40b2
refactor: improve jina embeddings v4 model
sigridjineth Jul 11, 2025
64c06c7
refactor: oom
sigridjineth Jul 11, 2025
56b7409
refactor: Validate lengths match
sigridjineth Jul 11, 2025
bef3df2
refactor: normalize
sigridjineth Jul 11, 2025
efa8b04
refactor: normalize
sigridjineth Jul 11, 2025
0fe30f8
refactor: review
sigridjineth Jul 11, 2025
062a156
refactor: prehook commits
sigridjineth Jul 16, 2025
edfe91a
fix: Apply isort formatting to jina_embeddings_v4.py
Jul 16, 2025
5d12bd4
[ci skip-hooks] Formatting attempt(s)
Jul 16, 2025
27b28f7
fix: Resolve yapf/isort conflict with disable comments
Jul 17, 2025
3bdbd17
refactor: accept review
Jul 17, 2025
fafd668
refactor: address review feedback for Jina embeddings V4
Jul 18, 2025
0c3f1bd
refactor: import HAS_TRITON from triton_utils instead of local defini…
Jul 18, 2025
5c45015
refactor: rename example file to follow existing embedding pattern
Jul 18, 2025
9d34781
Merge remote-tracking branch 'origin/main' into jina-support
Jul 18, 2025
8e0578a
refactor: update JinaVLForEmbedding to comply with new pooling archit…
Jul 18, 2025
eb1497e
refactor: use pooler utility functions to avoid duplicate code
Jul 18, 2025
1b4f405
refactor: address maintainer review comments for JinaVLPooler
Jul 18, 2025
702fd16
perf: optimize vision token detection using torch.isin
Jul 18, 2025
5114a3c
fix: introducing dedicated VisionPooler class
Jul 18, 2025
6b501b2
feat: add vision pooling support for jina embeddings v4
Jul 19, 2025
37 changes: 37 additions & 0 deletions tests/models/pooling/test_jina_embeddings_v4.py
@@ -342,3 +342,40 @@ def test_vision_only_pooling(self, model):
# embeddings should be very similar despite different text
similarity = torch.dot(emb1, emb2).item()
assert similarity > 0.99 # Should be nearly identical


class TestVisionPooler:
"""Test the VisionPooler class."""

def test_vision_pooler(self):
"""Test that the VisionPooler correctly pools vision tokens."""
from vllm.config import ModelConfig
from vllm.model_executor.layers.pooler import VisionPooler
from vllm.pooling_params import PoolingParams
from vllm.v1.pool.metadata import PoolingMetadata

model_config = ModelConfig(model_name, task="embed")
model_config.hf_config.vision_start_token_id = VISION_START_TOKEN_ID
model_config.hf_config.vision_end_token_id = VISION_END_TOKEN_ID
model_config.hidden_size = 4

pooler = VisionPooler(model_config)

hidden_states = torch.randn(10, 4)
prompt_token_ids = torch.tensor([[
1, 2, VISION_START_TOKEN_ID, 4, VISION_END_TOKEN_ID, 6, 7, 8, 9, 10
]])
prompt_lens = torch.tensor([10])

pooling_metadata = PoolingMetadata(prompt_lens=prompt_lens,
prompt_token_ids=prompt_token_ids,
pooling_params=[PoolingParams()])

output = pooler.forward(hidden_states, pooling_metadata)

vision_tokens = hidden_states[2:5]
expected_output = vision_tokens.mean(dim=0)

assert torch.allclose(output.outputs[0].data,
expected_output,
atol=1e-5)
5 changes: 3 additions & 2 deletions vllm/config.py
@@ -3256,9 +3256,10 @@ def get_limit_per_prompt(self, modality: str) -> int:
@config
@dataclass
class PoolerConfig:
"""Controls the behavior of output pooling in pooling models."""
"""Configuration for the pooler."""

pooling_type: Optional[str] = None
pooling_type: Optional[Literal["last", "all", "cls", "step", "mean",
                               "vision"]] = None

Member comment: Actually, the pooling type here is supposed to be upper case. (See the sketch after this hunk.)
"""
The pooling method of the pooling model. This should be a key in
[`vllm.model_executor.layers.pooler.PoolingType`][].
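What the reviewer appears to be asking for is the upper-case form, matching the PoolingType member names rather than lower-case strings. A minimal sketch of the suggested annotation (illustrative only, not part of the diff):

from typing import Literal, Optional

# Upper-case values mirroring the PoolingType member names (sketch only).
pooling_type: Optional[Literal["LAST", "ALL", "CLS", "STEP", "MEAN",
                               "VISION"]] = None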
159 changes: 142 additions & 17 deletions vllm/model_executor/layers/pooler.py
@@ -32,6 +32,7 @@
CLS = 2
STEP = 3
MEAN = 4
VISION = 5
Author comment: I have created this new VISION pooling type in the PoolingType enum. (An illustrative sketch follows the dispatch hunk below.)


@dataclass(frozen=True)
@@ -91,6 +92,8 @@

if pooling_type == PoolingType.STEP:
return StepPooler.from_config(resolved_config)
if pooling_type == PoolingType.VISION:
return VisionPooler.from_config(resolved_config)

return SimplePooler.from_config(resolved_config)
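To connect the author's comment with this hunk: the new enum member is what the dispatch keys on. A self-contained sketch of the idea follows; the IntEnum base and the values of the first two members are assumptions (only CLS through VISION are visible in the diff), and the lookup-by-name at the end is an illustration of why the review comment on PoolerConfig asks for upper-case strings.

from enum import IntEnum


class PoolingType(IntEnum):
    LAST = 0      # value assumed
    ALL = 1       # value assumed
    CLS = 2
    STEP = 3
    MEAN = 4
    VISION = 5    # added by this PR


# If the configured string is resolved by enum member name, it must be
# upper case -- the point of the review comment on PoolerConfig above.
assert PoolingType["VISION"] is PoolingType.VISION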

@@ -622,6 +625,86 @@
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]


class VisionPooler(Pooler):

@classmethod
def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
return cls(model_config)

def __init__(self, config: ModelConfig):
super().__init__()
self.config = config

def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
if task == "embed":
return PoolingParams(pooling_type="vision",
logits_processing_needs_token_ids=True)
return None

def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
assert isinstance(pooling_metadata, V1PoolingMetadata)

pooled_outputs = []
for i in range(len(pooling_metadata.prompt_lens)):
start_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
hf_config.vision_start_token_id).nonzero()[-1].item()
end_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
hf_config.vision_end_token_id).nonzero()[-1].item()

seq_start = torch.cumsum(
torch.tensor([0] + pooling_metadata.prompt_lens.tolist()),
dim=0)[i]
seq_len = pooling_metadata.prompt_lens[i]

output = torch.empty(self.config.hidden_size,
device=hidden_states.device,
dtype=hidden_states.dtype)

grid = lambda meta: (self.config.hidden_size, )
mean_pool_with_position_kernel[grid](hidden_states, output,
seq_start, seq_len,
self.config.hidden_size,
start_pos, end_pos + 1)

pooled_outputs.append(output)

return build_output(torch.stack(pooled_outputs))


if HAS_TRITON:

@triton.jit
def mean_pool_with_position_kernel(
hidden_states_ptr,
output_ptr,
seq_start,
seq_len,
hidden_size,
pool_start,
pool_end,
BLOCK_SIZE: tl.constexpr,
):
"""Triton kernel to perform mean pooling over a specified token range."""
pid = tl.program_id(0)

if pid >= hidden_size:
return

accumulator = 0.0
for i in range(pool_start, pool_end):
hidden_val = tl.load(hidden_states_ptr +
(seq_start + i) * hidden_size + pid)
accumulator += hidden_val

# Store mean pooled result
result = accumulator / (pool_end - pool_start)
tl.store(output_ptr + pid, result)
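For readers without Triton: per sequence, the kernel above computes a plain mean over the inclusive span from the last vision-start token to the last vision-end token, which is also what the unit test checks. A minimal PyTorch equivalent for a single sequence (illustrative only; no such fallback exists in the PR):

import torch


def mean_pool_vision_span(hidden_states: torch.Tensor,
                          token_ids: torch.Tensor,
                          vision_start_id: int,
                          vision_end_id: int) -> torch.Tensor:
    # Same span selection as VisionPooler.forward: last occurrence of each
    # marker token, pooled inclusively over [start, end].
    start = int((token_ids == vision_start_id).nonzero()[-1])
    end = int((token_ids == vision_end_id).nonzero()[-1])
    return hidden_states[start:end + 1].mean(dim=0)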


class ClassifierPooler(nn.Module):
"""A pooling layer for classification tasks.

@@ -640,7 +723,7 @@
pooling: PoolingFn,
classifier: ClassifierFn,
act_fn: Optional[PoolerActivation] = None,
) -> None:

Check failure on line 726 in vllm/model_executor/layers/pooler.py (GitHub Actions / pre-commit): Ruff (E501) Line too long (81 > 80).
super().__init__()

self.pooling = pooling
@@ -709,39 +792,81 @@
return build_output(scores)


class VisionPooler(Pooler):

@classmethod
def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
return cls(model_config)

def __init__(self, config: ModelConfig):
Member comment: Can we pass in the token IDs and hidden size explicitly, in case other models store those attributes in different locations? (A sketch of this suggestion follows the class below.)
super().__init__()
self.config = config

def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
if task == "embed":
return PoolingParams(pooling_type="vision",
logits_processing_needs_token_ids=True)
return None

def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
assert isinstance(pooling_metadata, V1PoolingMetadata)

pooled_outputs = []
for i in range(len(pooling_metadata.prompt_lens)):
start_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
hf_config.vision_start_token_id).nonzero()[-1].item()
end_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
hf_config.vision_end_token_id).nonzero()[-1].item()

seq_start = torch.cumsum(
torch.tensor([0] + pooling_metadata.prompt_lens.tolist()),
dim=0)[i]
seq_len = pooling_metadata.prompt_lens[i]

output = torch.empty(self.config.hidden_size,
device=hidden_states.device,
dtype=hidden_states.dtype)

grid = lambda meta: (self.config.hidden_size, )
mean_pool_with_position_kernel[grid](hidden_states, output,
seq_start, seq_len,

Check failure on line 836 in vllm/model_executor/layers/pooler.py (GitHub Actions / pre-commit): mypy reports Name "VisionPooler" already defined on line 663 [no-redef] (repeated across several jobs), and Ruff (F811) reports vllm/model_executor/layers/pooler.py:836:7: Redefinition of unused `VisionPooler` from line 663.
self.config.hidden_size,
start_pos, end_pos + 1)

pooled_outputs.append(output)

return build_output(torch.stack(pooled_outputs))
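A rough sketch of the reviewer's suggestion above: construct the pooler from explicit token IDs and hidden size instead of reading them off hf_config inside forward. The signature is hypothetical, not code from this PR, and it reuses the Pooler and ModelConfig names already imported in this file.

class VisionPooler(Pooler):

    def __init__(self, vision_start_token_id: int,
                 vision_end_token_id: int, hidden_size: int) -> None:
        super().__init__()
        # Explicit attributes, so models that keep these IDs somewhere other
        # than hf_config can still construct the pooler.
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id
        self.hidden_size = hidden_size

    @classmethod
    def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
        hf_config = model_config.hf_config
        return cls(hf_config.vision_start_token_id,
                   hf_config.vision_end_token_id,
                   model_config.hidden_size)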


if HAS_TRITON:

@triton.jit
def extract_vision_tokens_kernel(
def mean_pool_with_position_kernel(
hidden_states_ptr,
token_ids_ptr,
output_ptr,
seq_start,
seq_len,
hidden_size,
vision_start_id: tl.constexpr,
vision_end_id: tl.constexpr,
pool_start,
pool_end,
BLOCK_SIZE: tl.constexpr,
):
"""Triton kernel to extract and pool vision tokens efficiently."""
"""Triton kernel to perform mean pooling over a specified token range."""
pid = tl.program_id(0)

if pid >= hidden_size:
return

# Find vision token range
vision_count = 0
accumulator = 0.0

for i in range(seq_len):
token_id = tl.load(token_ids_ptr + seq_start + i)
if token_id >= vision_start_id and token_id <= vision_end_id:
hidden_val = tl.load(hidden_states_ptr +
(seq_start + i) * hidden_size + pid)
accumulator += hidden_val
vision_count += 1
for i in range(pool_start, pool_end):
hidden_val = tl.load(hidden_states_ptr +
(seq_start + i) * hidden_size + pid)
accumulator += hidden_val

# Store mean pooled result
result = accumulator / vision_count if vision_count > 0 else 0.0

result = accumulator / (pool_end - pool_start)
tl.store(output_ptr + pid, result)