Skip to content

Commit d644c5c

Browse files
committed
refactor: improve jina embeddings v4 model
1 parent 1b594aa commit d644c5c

File tree

1 file changed

+34
-20
lines changed

1 file changed

+34
-20
lines changed

vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 34 additions & 20 deletions
Original file line number · Diff line number · Diff line change
@@ -3,7 +3,7 @@
33
import time
44
from array import array
55
from collections.abc import Iterable
6-
from typing import Optional, Tuple, List
6+
from typing import Optional, Tuple, List, Union
77

88
import torch
99
import torch.nn.functional as F
@@ -21,9 +21,11 @@
2121
from vllm.config import VllmConfig
2222
from vllm.logger import init_logger
2323
from vllm.model_executor.layers.pooler import Pooler, PoolingType
24-
from vllm.model_executor.pooling_metadata import PoolingMetadata, PoolingTensors
24+
from vllm.model_executor.pooling_metadata import (
25+
PoolingMetadata as V0PoolingMetadata, PoolingTensors)
2526
from vllm.multimodal import MULTIMODAL_REGISTRY
2627
from vllm.sequence import IntermediateTensors, PoolerOutput, PoolingSequenceGroupOutput
28+
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
2729

2830
from .interfaces import SupportsMultiModal, SupportsCrossEncoding
2931
from .qwen2_vl import (Qwen2VLDummyInputsBuilder,
@@ -41,6 +43,9 @@
4143
MAX_SEQUENCE_LENGTH = 512 * 1024 # 512K tokens
4244

4345

46+
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
47+
48+
4449
# Triton kernel for optimized vision token extraction
4550
if HAS_TRITON:
4651
@triton.jit
@@ -120,14 +125,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
120125
logger.info("Initialized JinaVLForEmbedding with thread-safe pooling")
121126

122127
def _extract_token_ids_safe(
123-
self,
128+
self,
124129
pooling_metadata: PoolingMetadata
125130
) -> Tuple[List[array], List[int]]:
126131
"""Safely extract token IDs from pooling metadata."""
132+
token_ids_list: List[array] = []
127133
try:
134+
if isinstance(pooling_metadata, V1PoolingMetadata):
135+
# For V1, we get token IDs and sequence indices directly
136+
for i, num in enumerate(pooling_metadata.prompt_lens):
137+
token_ids = pooling_metadata.prompt_token_ids[i, :num].tolist()
138+
token_ids_list.append(array('l', token_ids))
139+
140+
# V1 metadata does not have explicit seq_ids, so we use indices
141+
seq_ids = list(range(len(token_ids_list)))
142+
return token_ids_list, seq_ids
143+
144+
# For V0, we extract from seq_groups and seq_data
128145
seq_ids = []
129-
token_ids_list = []
130-
131146
for seq_group, _ in pooling_metadata.seq_groups:
132147
for seq_id in seq_group:
133148
if seq_id not in pooling_metadata.seq_data:
@@ -151,7 +166,8 @@ def _extract_token_ids_safe(
151166
return token_ids_list, seq_ids
152167

153168
except Exception as e:
154-
logger.error(f"Error extracting token IDs: {e}")
169+
logger.error(f"Error extracting token IDs: {e}. "
170+
f"Extracted {len(token_ids_list)} sequences before failure")
155171
raise
156172

157173
def _apply_vision_pooling_optimized(
@@ -291,10 +307,13 @@ def pooler(
291307
# Fallback to base pooler
292308
return self._base_pooler(hidden_states, pooling_metadata)
293309

294-
# Get prompt lengths
295-
prompt_lens = PoolingTensors.from_pooling_metadata(
296-
pooling_metadata, hidden_states.device
297-
).prompt_lens
310+
# Get prompt lengths based on metadata type
311+
if isinstance(pooling_metadata, V1PoolingMetadata):
312+
prompt_lens = pooling_metadata.prompt_lens
313+
else:
314+
prompt_lens = PoolingTensors.from_pooling_metadata(
315+
pooling_metadata, hidden_states.device
316+
).prompt_lens
298317

299318
# Validate lengths match
300319
if len(token_ids_list) != len(prompt_lens):
@@ -308,16 +327,11 @@ def pooler(
308327
)
309328
except RuntimeError as e:
310329
if "out of memory" in str(e).lower():
311-
logger.warning("OOM during pooling, falling back to sequential processing")
312-
# Process sequences one by one to reduce memory
313-
pooled_data = []
314-
for i in range(len(token_ids_list)):
315-
single_pooled = self._apply_vision_pooling_pytorch(
316-
hidden_states,
317-
[token_ids_list[i]],
318-
prompt_lens[i:i+1]
319-
)
320-
pooled_data.extend(single_pooled)
330+
logger.warning("OOM during optimized pooling, falling back to batched PyTorch")
331+
# Fallback to a more memory-efficient PyTorch implementation
332+
pooled_data = self._apply_vision_pooling_pytorch(
333+
hidden_states, token_ids_list, prompt_lens
334+
)
321335
else:
322336
raise
323337

0 commit comments

Comments (0)