3
3
import time
4
4
from array import array
5
5
from collections .abc import Iterable
6
- from typing import Optional , Tuple , List
6
+ from typing import Optional , Tuple , List , Union
7
7
8
8
import torch
9
9
import torch .nn .functional as F
21
21
from vllm .config import VllmConfig
22
22
from vllm .logger import init_logger
23
23
from vllm .model_executor .layers .pooler import Pooler , PoolingType
24
- from vllm .model_executor .pooling_metadata import PoolingMetadata , PoolingTensors
24
+ from vllm .model_executor .pooling_metadata import (
25
+ PoolingMetadata as V0PoolingMetadata , PoolingTensors )
25
26
from vllm .multimodal import MULTIMODAL_REGISTRY
26
27
from vllm .sequence import IntermediateTensors , PoolerOutput , PoolingSequenceGroupOutput
28
+ from vllm .v1 .pool .metadata import PoolingMetadata as V1PoolingMetadata
27
29
28
30
from .interfaces import SupportsMultiModal , SupportsCrossEncoding
29
31
from .qwen2_vl import (Qwen2VLDummyInputsBuilder ,
41
43
MAX_SEQUENCE_LENGTH = 512 * 1024 # 512K tokens
42
44
43
45
46
+ PoolingMetadata = Union [V0PoolingMetadata , V1PoolingMetadata ]
47
+
48
+
44
49
# Triton kernel for optimized vision token extraction
45
50
if HAS_TRITON :
46
51
@triton .jit
@@ -120,14 +125,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
120
125
logger .info ("Initialized JinaVLForEmbedding with thread-safe pooling" )
121
126
122
127
def _extract_token_ids_safe (
123
- self ,
128
+ self ,
124
129
pooling_metadata : PoolingMetadata
125
130
) -> Tuple [List [array ], List [int ]]:
126
131
"""Safely extract token IDs from pooling metadata."""
132
+ token_ids_list : List [array ] = []
127
133
try :
134
+ if isinstance (pooling_metadata , V1PoolingMetadata ):
135
+ # For V1, we get token IDs and sequence indices directly
136
+ for i , num in enumerate (pooling_metadata .prompt_lens ):
137
+ token_ids = pooling_metadata .prompt_token_ids [i , :num ].tolist ()
138
+ token_ids_list .append (array ('l' , token_ids ))
139
+
140
+ # V1 metadata does not have explicit seq_ids, so we use indices
141
+ seq_ids = list (range (len (token_ids_list )))
142
+ return token_ids_list , seq_ids
143
+
144
+ # For V0, we extract from seq_groups and seq_data
128
145
seq_ids = []
129
- token_ids_list = []
130
-
131
146
for seq_group , _ in pooling_metadata .seq_groups :
132
147
for seq_id in seq_group :
133
148
if seq_id not in pooling_metadata .seq_data :
@@ -151,7 +166,8 @@ def _extract_token_ids_safe(
151
166
return token_ids_list , seq_ids
152
167
153
168
except Exception as e :
154
- logger .error (f"Error extracting token IDs: { e } " )
169
+ logger .error (f"Error extracting token IDs: { e } . "
170
+ f"Extracted { len (token_ids_list )} sequences before failure" )
155
171
raise
156
172
157
173
def _apply_vision_pooling_optimized (
@@ -291,10 +307,13 @@ def pooler(
291
307
# Fallback to base pooler
292
308
return self ._base_pooler (hidden_states , pooling_metadata )
293
309
294
- # Get prompt lengths
295
- prompt_lens = PoolingTensors .from_pooling_metadata (
296
- pooling_metadata , hidden_states .device
297
- ).prompt_lens
310
+ # Get prompt lengths based on metadata type
311
+ if isinstance (pooling_metadata , V1PoolingMetadata ):
312
+ prompt_lens = pooling_metadata .prompt_lens
313
+ else :
314
+ prompt_lens = PoolingTensors .from_pooling_metadata (
315
+ pooling_metadata , hidden_states .device
316
+ ).prompt_lens
298
317
299
318
# Validate lengths match
300
319
if len (token_ids_list ) != len (prompt_lens ):
@@ -308,16 +327,11 @@ def pooler(
308
327
)
309
328
except RuntimeError as e :
310
329
if "out of memory" in str (e ).lower ():
311
- logger .warning ("OOM during pooling, falling back to sequential processing" )
312
- # Process sequences one by one to reduce memory
313
- pooled_data = []
314
- for i in range (len (token_ids_list )):
315
- single_pooled = self ._apply_vision_pooling_pytorch (
316
- hidden_states ,
317
- [token_ids_list [i ]],
318
- prompt_lens [i :i + 1 ]
319
- )
320
- pooled_data .extend (single_pooled )
330
+ logger .warning ("OOM during optimized pooling, falling back to batched PyTorch" )
331
+ # Fallback to a more memory-efficient PyTorch implementation
332
+ pooled_data = self ._apply_vision_pooling_pytorch (
333
+ hidden_states , token_ids_list , prompt_lens
334
+ )
321
335
else :
322
336
raise
323
337
0 commit comments