Skip to content

Commit 5114a3c

Browse files
author
Sigrid Jin (Sionic AI)
committed
fix: introduce dedicated VisionPooler class
Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
1 parent 702fd16 commit 5114a3c

File tree

4 files changed

+189
-287
lines changed

4 files changed

+189
-287
lines changed

tests/models/pooling/test_jina_embeddings_v4.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,40 @@ def test_vision_only_pooling(self, model):
342342
# embeddings should be very similar despite different text
343343
similarity = torch.dot(emb1, emb2).item()
344344
assert similarity > 0.99 # Should be nearly identical
345+
346+
347+
class TestVisionPooler:
    """Test the VisionPooler class."""

    def test_vision_pooler(self):
        """Test that the VisionPooler correctly pools vision tokens.

        Builds a 10-token prompt whose tokens 2..4 are the vision span
        (start marker, one vision token, end marker) and checks that the
        pooled embedding equals the mean of those three hidden states.
        """
        from vllm.config import ModelConfig
        from vllm.model_executor.layers.pooler import VisionPooler
        from vllm.pooling_params import PoolingParams
        from vllm.v1.pool.metadata import PoolingMetadata

        # NOTE(review): assumes ModelConfig.hidden_size is a plain writable
        # attribute (not a derived property) — confirm against vllm.config.
        model_config = ModelConfig(model_name, task="embed")
        model_config.hf_config.vision_start_token_id = VISION_START_TOKEN_ID
        model_config.hf_config.vision_end_token_id = VISION_END_TOKEN_ID
        model_config.hidden_size = 4

        pooler = VisionPooler(model_config)

        hidden_states = torch.randn(10, 4)
        prompt_token_ids = torch.tensor([[
            1, 2, VISION_START_TOKEN_ID, 4, VISION_END_TOKEN_ID, 6, 7, 8, 9, 10
        ]])
        prompt_lens = torch.tensor([10])

        pooling_metadata = PoolingMetadata(prompt_lens=prompt_lens,
                                           prompt_token_ids=prompt_token_ids,
                                           pooling_params=[PoolingParams()])

        output = pooler.forward(hidden_states, pooling_metadata)

        # Indices 2..4 span the start marker through the end marker,
        # both inclusive — matching the pooler's [start_pos, end_pos + 1)
        # kernel range.
        vision_tokens = hidden_states[2:5]
        expected_output = vision_tokens.mean(dim=0)

        assert torch.allclose(output.outputs[0].data,
                              expected_output,
                              atol=1e-5)

vllm/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3256,9 +3256,10 @@ def get_limit_per_prompt(self, modality: str) -> int:
32563256
@config
32573257
@dataclass
32583258
class PoolerConfig:
3259-
"""Controls the behavior of output pooling in pooling models."""
3259+
"""Configuration for the pooler."""
32603260

3261-
pooling_type: Optional[str] = None
3261+
pooling_type: Optional[Literal["last", "all", "cls", "step", "mean",
3262+
"vision"]] = None
32623263
"""
32633264
The pooling method of the pooling model. This should be a key in
32643265
[`vllm.model_executor.layers.pooler.PoolingType`][].

vllm/model_executor/layers/pooler.py

Lines changed: 142 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class PoolingType(IntEnum):
3232
CLS = 2
3333
STEP = 3
3434
MEAN = 4
35+
VISION = 5
3536

3637

3738
@dataclass(frozen=True)
@@ -91,6 +92,8 @@ def from_config_with_defaults(
9192

9293
if pooling_type == PoolingType.STEP:
9394
return StepPooler.from_config(resolved_config)
95+
if pooling_type == PoolingType.VISION:
96+
return VisionPooler.from_config(resolved_config)
9497

9598
return SimplePooler.from_config(resolved_config)
9699

@@ -622,6 +625,86 @@ def forward(
622625
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
623626

624627

628+
class VisionPooler(Pooler):
    """Mean-pools hidden states over each prompt's vision-token span.

    For every prompt, pools the hidden states of the tokens between the
    *last* ``vision_start_token_id`` and the *last* ``vision_end_token_id``
    (both markers inclusive), yielding one embedding per prompt.
    """

    @classmethod
    def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
        """Alternate constructor used by the pooler factory."""
        return cls(model_config)

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        """Only the ``"embed"`` task is supported; token ids are required
        to locate the vision span inside each prompt."""
        if task == "embed":
            return PoolingParams(pooling_type="vision",
                                 logits_processing_needs_token_ids=True)
        return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        assert isinstance(pooling_metadata, V1PoolingMetadata)

        # Hoisted out of the loop: prefix sums of the prompt lengths give
        # each sequence's start offset into the flattened hidden_states.
        # Recomputing the full cumsum per prompt (as before) was O(n^2)
        # in the number of prompts.
        seq_starts = torch.cumsum(
            torch.tensor([0] + pooling_metadata.prompt_lens.tolist()), dim=0)

        pooled_outputs = []
        for i in range(len(pooling_metadata.prompt_lens)):
            # Use the last occurrence of each marker in the prompt.
            start_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
                         hf_config.vision_start_token_id).nonzero()[-1].item()
            end_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
                       hf_config.vision_end_token_id).nonzero()[-1].item()

            seq_start = seq_starts[i]
            seq_len = pooling_metadata.prompt_lens[i]

            output = torch.empty(self.config.hidden_size,
                                 device=hidden_states.device,
                                 dtype=hidden_states.dtype)

            # One Triton program per hidden dimension.
            grid = lambda meta: (self.config.hidden_size, )
            mean_pool_with_position_kernel[grid](hidden_states, output,
                                                 seq_start, seq_len,
                                                 self.config.hidden_size,
                                                 start_pos, end_pos + 1)

            pooled_outputs.append(output)

        return build_output(torch.stack(pooled_outputs))
676+
677+
678+
if HAS_TRITON:

    @triton.jit
    def mean_pool_with_position_kernel(
        hidden_states_ptr,
        output_ptr,
        seq_start,
        seq_len,  # unused; kept so existing 7-argument launches still work
        hidden_size,
        pool_start,
        pool_end,
    ):
        """Triton kernel to perform mean pooling over a specified token range.

        Averages ``hidden_states[seq_start + pool_start : seq_start + pool_end]``
        along the token axis, one program per hidden dimension.

        Fix: dropped the unused ``BLOCK_SIZE: tl.constexpr`` parameter.
        The launch sites pass only seven arguments, and a constexpr
        parameter with no default must be supplied at launch time, so
        every launch of the previous signature would have failed.
        """
        pid = tl.program_id(0)

        if pid >= hidden_size:
            return

        accumulator = 0.0
        for i in range(pool_start, pool_end):
            hidden_val = tl.load(hidden_states_ptr +
                                 (seq_start + i) * hidden_size + pid)
            accumulator += hidden_val

        # Store mean pooled result. pool_end > pool_start is guaranteed by
        # the caller (it passes end_pos + 1 with end_pos >= start_pos).
        result = accumulator / (pool_end - pool_start)
        tl.store(output_ptr + pid, result)
706+
707+
625708
class ClassifierPooler(nn.Module):
626709
"""A pooling layer for classification tasks.
627710
@@ -709,39 +792,81 @@ def forward(
709792
return build_output(scores)
710793

711794

795+
# NOTE(review): this class definition appears twice in this diff (an
# identical copy is added earlier in the same file, likely a rebase
# artifact); the later definition silently shadows the earlier one.
# One copy should be removed — confirm against the full file.
class VisionPooler(Pooler):
    """Mean-pools hidden states over each prompt's vision-token span,
    delimited by the last vision start/end marker tokens (inclusive)."""

    @classmethod
    def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
        # Alternate constructor used by the pooler factory.
        return cls(model_config)

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        # Only the "embed" task is supported; token ids are needed to
        # locate the vision span inside each prompt.
        if task == "embed":
            return PoolingParams(pooling_type="vision",
                                 logits_processing_needs_token_ids=True)
        return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        assert isinstance(pooling_metadata, V1PoolingMetadata)

        pooled_outputs = []
        for i in range(len(pooling_metadata.prompt_lens)):
            # Positions of the *last* start/end markers in this prompt.
            start_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
                         hf_config.vision_start_token_id).nonzero()[-1].item()
            end_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
                       hf_config.vision_end_token_id).nonzero()[-1].item()

            # Prefix sum of prompt lengths gives this sequence's offset
            # into the flattened hidden_states.
            seq_start = torch.cumsum(
                torch.tensor([0] + pooling_metadata.prompt_lens.tolist()),
                dim=0)[i]
            seq_len = pooling_metadata.prompt_lens[i]

            output = torch.empty(self.config.hidden_size,
                                 device=hidden_states.device,
                                 dtype=hidden_states.dtype)

            # One Triton program per hidden dimension.
            grid = lambda meta: (self.config.hidden_size, )
            mean_pool_with_position_kernel[grid](hidden_states, output,
                                                 seq_start, seq_len,
                                                 self.config.hidden_size,
                                                 start_pos, end_pos + 1)

            pooled_outputs.append(output)

        return build_output(torch.stack(pooled_outputs))
843+
844+
712845
if HAS_TRITON:

    @triton.jit
    def mean_pool_with_position_kernel(
        hidden_states_ptr,
        output_ptr,
        seq_start,
        seq_len,  # unused; kept so existing 7-argument launches still work
        hidden_size,
        pool_start,
        pool_end,
    ):
        """Triton kernel to perform mean pooling over a specified token range.

        Replaces the old ``extract_vision_tokens_kernel``: instead of
        scanning token ids for marker values, the caller supplies the
        ``[pool_start, pool_end)`` range directly.

        Fix: dropped the unused ``BLOCK_SIZE: tl.constexpr`` parameter.
        The launch sites pass only seven arguments, and a constexpr
        parameter with no default must be supplied at launch time, so
        every launch of the previous signature would have failed.
        """
        pid = tl.program_id(0)

        if pid >= hidden_size:
            return

        accumulator = 0.0
        for i in range(pool_start, pool_end):
            hidden_val = tl.load(hidden_states_ptr +
                                 (seq_start + i) * hidden_size + pid)
            accumulator += hidden_val

        # Store mean pooled result. pool_end > pool_start is guaranteed by
        # the caller (it passes end_pos + 1 with end_pos >= start_pos).
        result = accumulator / (pool_end - pool_start)
        tl.store(output_ptr + pid, result)

0 commit comments

Comments
 (0)