Skip to content

Commit 702fd16

Browse files
author
Sigrid Jin (Sionic AI)
committed
perf: optimize vision token detection using torch.isin
Implement efficiency improvements suggested by DarkLight1337:

- Consolidate the get_pooling_params method for the "embed" and "encode" tasks
- Pre-compute the vision token IDs tensor in the constructor
- Replace range checks with torch.isin for more efficient vision token detection at lines 209-210 and 261-262

This reduces redundant code and improves performance when checking for vision tokens by using optimized tensor operations.

Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
1 parent 1b4f405 commit 702fd16

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def __init__(self,
5050
self.pooling_backend = pooling_backend
5151
self.observability_config = vllm_config.observability_config
5252

53+
# Pre-compute vision token IDs tensor for efficient checking
54+
self.vision_token_ids = torch.tensor(
55+
[VISION_START_TOKEN_ID, VISION_END_TOKEN_ID], dtype=torch.long)
56+
5357
# Performance tracking
5458
self._pooling_time_ms = 0.0
5559
self._pooling_count = 0
@@ -64,10 +68,7 @@ def __init__(self,
6468

6569
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
6670
"""Return pooling params for embedding task."""
67-
if task == "embed":
68-
return PoolingParams(logits_processing_needs_token_ids=True)
69-
70-
if task == "encode":
71+
if task == "embed" or task == "encode":
7172
return PoolingParams(logits_processing_needs_token_ids=True)
7273

7374
# The equalities are split up to keep mypy happy
@@ -206,8 +207,8 @@ def _apply_vision_pooling_optimized(
206207
dtype=hidden_states.dtype)
207208

208209
# Check for vision tokens
209-
has_vision = torch.any((token_tensor >= VISION_START_TOKEN_ID)
210-
& (token_tensor <= VISION_END_TOKEN_ID))
210+
has_vision = torch.isin(token_tensor,
211+
self.vision_token_ids.to(device)).any()
211212

212213
if has_vision:
213214
# Use Triton kernel for vision token extraction
@@ -258,8 +259,8 @@ def _apply_vision_pooling_pytorch(
258259
device=hidden_states.device)
259260

260261
# Check for vision tokens
261-
vision_mask = ((seq_tokens >= VISION_START_TOKEN_ID) &
262-
(seq_tokens <= VISION_END_TOKEN_ID))
262+
vision_mask = torch.isin(
263+
seq_tokens, self.vision_token_ids.to(seq_tokens.device))
263264

264265
if vision_mask.any():
265266
# Pool only vision tokens

0 commit comments

Comments
 (0)