Commit 6b31c66

Author: Sigrid Jin (Sionic AI)
refactor: revert to range checks for vision token detection
Revert the torch.isin optimization as pointed out by DarkLight1337. The torch.isin approach only matches tokens that are exactly vision_start_id or vision_end_id, but we need to match ALL tokens in the range [vision_start_id, vision_end_id].

The range check correctly handles:
- All tokens within the range (not just endpoints)
- Future expansion of the vision token range
- Proper semantic intent of the code

Keep the get_pooling_params consolidation as that change is correct.

Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
1 parent 702fd16 commit 6b31c66
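
The difference described in the commit message can be reproduced in a few lines. This is a minimal sketch, not code from the commit: the token ID values below are hypothetical placeholders (the real `VISION_START_TOKEN_ID` and `VISION_END_TOKEN_ID` are defined elsewhere in the model file and are not shown in this diff), and `endpoint_ids` stands in for the removed `self.vision_token_ids` tensor.

```python
import torch

# Hypothetical IDs for illustration only; the real constants are not shown in this diff.
VISION_START_TOKEN_ID = 100
VISION_END_TOKEN_ID = 103

# A toy sequence: two text tokens surrounding four tokens in the vision range.
token_tensor = torch.tensor([5, 100, 101, 102, 103, 7])

# Reverted approach: membership test against the two endpoint IDs only.
endpoint_ids = torch.tensor(
    [VISION_START_TOKEN_ID, VISION_END_TOKEN_ID], dtype=torch.long)
isin_mask = torch.isin(token_tensor, endpoint_ids)

# Restored approach: inclusive range check over [start, end].
range_mask = ((token_tensor >= VISION_START_TOKEN_ID)
              & (token_tensor <= VISION_END_TOKEN_ID))

print(isin_mask.tolist())   # [False, True, False, False, True, False]
print(range_mask.tolist())  # [False, True, True, True, True, False]
```

With these placeholder values, the membership test misses the interior IDs 101 and 102, while the range check marks every token between the two endpoints.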

1 file changed: +4 -8 lines changed

vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 4 additions & 8 deletions
@@ -50,10 +50,6 @@ def __init__(self,
         self.pooling_backend = pooling_backend
         self.observability_config = vllm_config.observability_config
 
-        # Pre-compute vision token IDs tensor for efficient checking
-        self.vision_token_ids = torch.tensor(
-            [VISION_START_TOKEN_ID, VISION_END_TOKEN_ID], dtype=torch.long)
-
         # Performance tracking
         self._pooling_time_ms = 0.0
         self._pooling_count = 0
@@ -207,8 +203,8 @@ def _apply_vision_pooling_optimized(
                                 dtype=hidden_states.dtype)
 
         # Check for vision tokens
-        has_vision = torch.isin(token_tensor,
-                                self.vision_token_ids.to(device)).any()
+        has_vision = torch.any((token_tensor >= VISION_START_TOKEN_ID)
+                               & (token_tensor <= VISION_END_TOKEN_ID))
 
         if has_vision:
             # Use Triton kernel for vision token extraction
@@ -259,8 +255,8 @@ def _apply_vision_pooling_pytorch(
                                 device=hidden_states.device)
 
         # Check for vision tokens
-        vision_mask = torch.isin(
-            seq_tokens, self.vision_token_ids.to(seq_tokens.device))
+        vision_mask = ((seq_tokens >= VISION_START_TOKEN_ID) &
+                       (seq_tokens <= VISION_END_TOKEN_ID))
 
         if vision_mask.any():
             # Pool only vision tokens
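
For context on how a mask like `vision_mask` might be consumed, here is a minimal sketch. The diff does not show the pooling code itself, so everything here is an assumption made for illustration: the function name `pool_sequence`, the choice of mean pooling, the fallback to pooling all tokens, and the token ID values are all hypothetical.

```python
import torch

# Hypothetical IDs for illustration only; the real constants are not shown in this diff.
VISION_START_TOKEN_ID = 100
VISION_END_TOKEN_ID = 103


def pool_sequence(hidden: torch.Tensor, seq_tokens: torch.Tensor) -> torch.Tensor:
    """Mean-pool vision-token rows if any exist, else all rows (assumed behavior)."""
    # Inclusive range check, matching the restored condition in the diff.
    vision_mask = ((seq_tokens >= VISION_START_TOKEN_ID) &
                   (seq_tokens <= VISION_END_TOKEN_ID))
    if vision_mask.any():
        # Boolean indexing keeps only the rows whose token falls in the vision range.
        return hidden[vision_mask].mean(dim=0)
    # Assumed fallback: pool over every token when no vision token is present.
    return hidden.mean(dim=0)


# Four tokens with 2-dimensional hidden states.
hidden = torch.tensor([[1.0, 1.0], [3.0, 3.0], [5.0, 5.0], [9.0, 9.0]])

with_vision = pool_sequence(hidden, torch.tensor([5, 100, 102, 7]))
text_only = pool_sequence(hidden, torch.tensor([1, 2, 3, 4]))

print(with_vision.tolist())  # [4.0, 4.0]
print(text_only.tolist())    # [4.5, 4.5]
```

Note that token 102 is an interior ID in this toy range, so a membership test against the two endpoints alone would leave it out of the pooled mean.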
