@@ -109,6 +109,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.hidden_size = vllm_config.model_config.hf_config.hidden_size
         pooler_config = vllm_config.model_config.pooler_config
         self.observability_config = vllm_config.observability_config
+
+        # Configuration for vision pooling backend
+        self.pooling_backend = getattr(vllm_config.model_config,
+                                       "jina_pooling_backend", "triton")
+        if self.pooling_backend not in ("triton", "pytorch"):
+            logger.warning(
+                f"Invalid jina_pooling_backend '{self.pooling_backend}'. "
+                f"Must be 'triton' or 'pytorch'. Defaulting to 'triton'.")
+            self.pooling_backend = "triton"
 
         # Initialize base pooler for fallback
         self._base_pooler = Pooler.from_config_with_defaults(
@@ -320,20 +329,15 @@ def pooler(
             logger.error(f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths")
             return self._base_pooler(hidden_states, pooling_metadata)
 
-        # Apply optimized pooling
-        try:
+        # Apply pooling based on configured backend
+        if self.pooling_backend == "triton":
             pooled_data = self._apply_vision_pooling_optimized(
                 hidden_states, token_ids_list, prompt_lens
             )
-        except RuntimeError as e:
-            if "out of memory" in str(e).lower():
-                logger.warning("OOM during optimized pooling, falling back to batched PyTorch")
-                # Fallback to a more memory-efficient PyTorch implementation
-                pooled_data = self._apply_vision_pooling_pytorch(
-                    hidden_states, token_ids_list, prompt_lens
-                )
-            else:
-                raise
+        else:  # self.pooling_backend == "pytorch"
+            pooled_data = self._apply_vision_pooling_pytorch(
+                hidden_states, token_ids_list, prompt_lens
+            )
 
         # Build output
         pooled_outputs = [
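
For context, a minimal standalone sketch of the control-flow change in the second hunk: the backend is now resolved once from configuration and dispatched with a plain branch, instead of attempting the optimized path and catching out-of-memory RuntimeErrors to fall back to PyTorch at pooling time. All names below (`resolve_backend`, `pool`, the stand-in pooling functions) are hypothetical illustrations, not part of the vLLM API; only the `"triton"`/`"pytorch"` choices and the default come from the diff above.

```python
# Hypothetical sketch of the config-driven backend selection and dispatch.
import logging

logger = logging.getLogger(__name__)

VALID_BACKENDS = ("triton", "pytorch")


def resolve_backend(configured):
    """Validate the configured backend name, defaulting to 'triton'."""
    backend = configured or "triton"
    if backend not in VALID_BACKENDS:
        logger.warning("Invalid backend %r; defaulting to 'triton'.", backend)
        backend = "triton"
    return backend


def pool_triton(hidden_states):   # stand-in for _apply_vision_pooling_optimized
    return hidden_states


def pool_pytorch(hidden_states):  # stand-in for _apply_vision_pooling_pytorch
    return hidden_states


def pool(hidden_states, backend):
    # Dispatch is a plain branch; no try/except OOM fallback remains.
    if backend == "triton":
        return pool_triton(hidden_states)
    return pool_pytorch(hidden_states)


print(pool([1, 2, 3], resolve_backend("not-a-backend")))  # falls back to triton
```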