Skip to content

Commit 1b4f405

Browse files
author
Sigrid Jin (Sionic AI)
committed
refactor: address maintainer review comments for JinaVLPooler
Address DarkLight1337's review feedback:
- Set logits_processing_needs_token_ids=True for V1 compatibility in both "embed" and "encode" tasks
- Support "encode" task by returning PoolingParams() instead of None
- Update log message from "thread-safe pooling" to "vision-aware pooling" to better reflect the actual functionality
- Remove unused seq_ids variable from _extract_token_ids_safe method

These changes ensure proper V1 compatibility and cleaner code structure.

Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
1 parent eb1497e commit 1b4f405

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 11 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -65,10 +65,13 @@ def __init__(self,
6565
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
6666
"""Return pooling params for embedding task."""
6767
if task == "embed":
68-
return PoolingParams()
68+
return PoolingParams(logits_processing_needs_token_ids=True)
69+
70+
if task == "encode":
71+
return PoolingParams(logits_processing_needs_token_ids=True)
6972

7073
# The equalities are split up to keep mypy happy
71-
if task == "encode" or task == "classify" or task == "score":
74+
if task == "classify" or task == "score":
7275
return None
7376

7477
assert_never(task)
@@ -87,8 +90,7 @@ def forward(
8790
return build_output(torch.empty((0, 0)))
8891

8992
# Extract token IDs safely from metadata
90-
token_ids_list, seq_ids = self._extract_token_ids_safe(
91-
pooling_metadata)
93+
token_ids_list = self._extract_token_ids_safe(pooling_metadata)
9294

9395
if not token_ids_list:
9496
logger.warning("No valid sequences found for pooling")
@@ -127,24 +129,20 @@ def forward(
127129
return build_output(pooled_tensor)
128130

129131
def _extract_token_ids_safe(
130-
self, pooling_metadata: PoolingMetadata
131-
) -> tuple[list[array], list[int]]:
132+
self, pooling_metadata: PoolingMetadata) -> list[array]:
132133
"""Safely extract token IDs from pooling metadata."""
133134
token_ids_list: list[array] = []
134135
try:
135136
if isinstance(pooling_metadata, V1PoolingMetadata):
136-
# For V1, we get token IDs and sequence indices directly
137+
# For V1, we get token IDs directly
137138
for i, num in enumerate(pooling_metadata.prompt_lens):
138139
token_ids = pooling_metadata.prompt_token_ids[
139140
i, :num].tolist()
140141
token_ids_list.append(array('l', token_ids))
141142

142-
# V1 metadata does not have explicit seq_ids, so we use indices
143-
seq_ids = list(range(len(token_ids_list)))
144-
return token_ids_list, seq_ids
143+
return token_ids_list
145144

146145
# For V0, we extract from seq_groups and seq_data
147-
seq_ids = []
148146
for seq_group, _ in pooling_metadata.seq_groups:
149147
for seq_id in seq_group:
150148
if seq_id not in pooling_metadata.seq_data:
@@ -164,10 +162,9 @@ def _extract_token_ids_safe(
164162
seq_id)
165163
continue
166164

167-
seq_ids.append(seq_id)
168165
token_ids_list.append(token_ids)
169166

170-
return token_ids_list, seq_ids
167+
return token_ids_list
171168

172169
except Exception as e:
173170
logger.error(
@@ -319,7 +316,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
319316
# Initialize the vision-aware pooler
320317
self.pooler = JinaVLPooler(vllm_config, self.pooling_backend)
321318

322-
logger.info("Initialized JinaVLForEmbedding with thread-safe pooling")
319+
logger.info("Initialized JinaVLForEmbedding with vision-aware pooling")
323320

324321
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
325322
"""Load weights with validation and error handling."""

0 commit comments

Comments
 (0)