
Commit c50992a

Author: Sigrid Jin (Sionic AI)

feat: add vision pooling support for jina embeddings v4

Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
1 parent 5114a3c commit c50992a

File tree

3 files changed: +17, -88 lines

vllm/config.py

Lines changed: 1 addition & 1 deletion
@@ -3256,7 +3256,7 @@ def get_limit_per_prompt(self, modality: str) -> int:
 @config
 @dataclass
 class PoolerConfig:
-    """Configuration for the pooler."""
+    """Controls the behavior of output pooling in pooling models."""
 
     pooling_type: Optional[Literal["last", "all", "cls", "step", "mean",
                                    "vision"]] = None

vllm/model_executor/layers/pooler.py

Lines changed: 15 additions & 86 deletions
@@ -625,56 +625,6 @@ def forward(
 ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
 
 
-class VisionPooler(Pooler):
-
-    @classmethod
-    def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
-        return cls(model_config)
-
-    def __init__(self, config: ModelConfig):
-        super().__init__()
-        self.config = config
-
-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
-        if task == "embed":
-            return PoolingParams(pooling_type="vision",
-                                 logits_processing_needs_token_ids=True)
-        return None
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> PoolerOutput:
-        assert isinstance(pooling_metadata, V1PoolingMetadata)
-
-        pooled_outputs = []
-        for i in range(len(pooling_metadata.prompt_lens)):
-            start_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
-                         hf_config.vision_start_token_id).nonzero()[-1].item()
-            end_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
-                       hf_config.vision_end_token_id).nonzero()[-1].item()
-
-            seq_start = torch.cumsum(
-                torch.tensor([0] + pooling_metadata.prompt_lens.tolist()),
-                dim=0)[i]
-            seq_len = pooling_metadata.prompt_lens[i]
-
-            output = torch.empty(self.config.hidden_size,
-                                 device=hidden_states.device,
-                                 dtype=hidden_states.dtype)
-
-            grid = lambda meta: (self.config.hidden_size, )
-            mean_pool_with_position_kernel[grid](hidden_states, output,
-                                                 seq_start, seq_len,
-                                                 self.config.hidden_size,
-                                                 start_pos, end_pos + 1)
-
-            pooled_outputs.append(output)
-
-        return build_output(torch.stack(pooled_outputs))
-
-
 if HAS_TRITON:
 
     @triton.jit
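
The class deleted above evidently duplicated a VisionPooler defined later in the same file (the surviving copy is updated in the next hunk, and the import in jina_embeddings_v4.py still resolves). Its forward() addresses the flattened batch through a cumulative sum of prompt lengths; a standalone sketch of that indexing idiom with toy lengths:

    import torch

    prompt_lens = torch.tensor([4, 3, 5])
    # offsets[i] is where prompt i starts in the flattened hidden_states,
    # matching torch.cumsum(torch.tensor([0] + prompt_lens.tolist()), dim=0).
    offsets = torch.cumsum(torch.tensor([0] + prompt_lens.tolist()), dim=0)
    print(offsets.tolist())  # [0, 4, 7, 12]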
@@ -817,10 +767,12 @@ def forward(
 
         pooled_outputs = []
         for i in range(len(pooling_metadata.prompt_lens)):
-            start_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
-                         hf_config.vision_start_token_id).nonzero()[-1].item()
-            end_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
-                       hf_config.vision_end_token_id).nonzero()[-1].item()
+            start_pos = (pooling_metadata.prompt_token_ids[i] ==
+                         self.config.hf_config.vision_start_token_id
+                         ).nonzero()[-1].item()
+            end_pos = (pooling_metadata.prompt_token_ids[i] ==
+                       self.config.hf_config.vision_end_token_id
+                       ).nonzero()[-1].item()
 
             seq_start = torch.cumsum(
                 torch.tensor([0] + pooling_metadata.prompt_lens.tolist()),
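
The rewrapped lines keep the same lookup: the last occurrence of each vision marker token in the prompt. A toy illustration; the ids below mirror Qwen2-VL's vision_start/vision_end markers but are assumptions here, since the real values come from hf_config:

    import torch

    VISION_START, VISION_END = 151652, 151653  # assumed ids, illustration only

    prompt_token_ids = torch.tensor([101, VISION_START, 7, 8, 9, VISION_END, 102])
    start_pos = (prompt_token_ids == VISION_START).nonzero()[-1].item()
    end_pos = (prompt_token_ids == VISION_END).nonzero()[-1].item()
    print(start_pos, end_pos)  # 1 5 -> the pooled span is tokens 1..5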
@@ -832,41 +784,18 @@ def forward(
                                  dtype=hidden_states.dtype)
 
             grid = lambda meta: (self.config.hidden_size, )
-            mean_pool_with_position_kernel[grid](hidden_states, output,
-                                                 seq_start, seq_len,
-                                                 self.config.hidden_size,
-                                                 start_pos, end_pos + 1)
+            if HAS_TRITON:
+                mean_pool_with_position_kernel[grid](hidden_states, output,
+                                                     seq_start, seq_len,
+                                                     self.config.hidden_size,
+                                                     start_pos, end_pos + 1)
+            else:
+                # Fall back to a PyTorch implementation if Triton is not available.
+                vision_tokens_range = hidden_states[seq_start + start_pos:seq_start + end_pos + 1]
+                output = vision_tokens_range.mean(dim=0)
 
             pooled_outputs.append(output)
 
         return build_output(torch.stack(pooled_outputs))
 
 
-if HAS_TRITON:
-
-    @triton.jit
-    def mean_pool_with_position_kernel(
-        hidden_states_ptr,
-        output_ptr,
-        seq_start,
-        seq_len,
-        hidden_size,
-        pool_start,
-        pool_end,
-        BLOCK_SIZE: tl.constexpr,
-    ):
-        """Triton kernel to perform mean pooling over a specified token range."""
-        pid = tl.program_id(0)
-
-        if pid >= hidden_size:
-            return
-
-        accumulator = 0.0
-        for i in range(pool_start, pool_end):
-            hidden_val = tl.load(hidden_states_ptr +
-                                 (seq_start + i) * hidden_size + pid)
-            accumulator += hidden_val
-
-        # Store mean pooled result
-        result = accumulator / (pool_end - pool_start)
-        tl.store(output_ptr + pid, result)
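
With the duplicate kernel definition deleted, the new PyTorch fallback carries the semantics on Triton-less installs: a per-dimension mean over the vision-token span, both marker positions included. A self-contained check under toy shapes that the slice-and-mean matches the kernel's accumulate-and-divide loop:

    import torch

    hidden_size = 8
    hidden_states = torch.randn(12, hidden_size)  # flattened batch of tokens
    seq_start, start_pos, end_pos = 2, 1, 4       # toy offsets

    fallback = hidden_states[seq_start + start_pos:seq_start + end_pos + 1].mean(dim=0)

    # Mirror the kernel: accumulate rows start_pos..end_pos, divide by count.
    reference = sum(hidden_states[seq_start + i] for i in range(start_pos, end_pos + 1))
    reference = reference / (end_pos + 1 - start_pos)

    assert torch.allclose(fallback, reference)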

vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.pooler import Pooler, PoolingTask
+from vllm.model_executor.layers.pooler import Pooler, PoolingTask, VisionPooler
 # yapf: disable
 from vllm.model_executor.pooling_metadata import (
     PoolingMetadata as V0PoolingMetadata)
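
Given the widened import, the model presumably constructs the shared pooler rather than a local copy. A sketch of the likely wiring; the surrounding class body is an assumption, since this hunk shows only the import change:

    from vllm.model_executor.layers.pooler import VisionPooler

    # Hypothetical stub: the class and attribute names are illustrative.
    class JinaVLForEmbedding:
        def __init__(self, vllm_config):
            self.pooler = VisionPooler.from_config(vllm_config.model_config)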
