Commit 48c4ea4

refactor: oom
1 parent d644c5c

File tree

1 file changed: +15 -11 lines changed


vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 15 additions & 11 deletions
@@ -109,6 +109,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.hidden_size = vllm_config.model_config.hf_config.hidden_size
         pooler_config = vllm_config.model_config.pooler_config
         self.observability_config = vllm_config.observability_config
+
+        # Configuration for vision pooling backend
+        self.pooling_backend = getattr(vllm_config.model_config,
+                                       "jina_pooling_backend", "triton")
+        if self.pooling_backend not in ("triton", "pytorch"):
+            logger.warning(
+                f"Invalid jina_pooling_backend '{self.pooling_backend}'. "
+                f"Must be 'triton' or 'pytorch'. Defaulting to 'triton'.")
+            self.pooling_backend = "triton"

         # Initialize base pooler for fallback
         self._base_pooler = Pooler.from_config_with_defaults(
@@ -320,20 +329,15 @@ def pooler(
             logger.error(f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths")
             return self._base_pooler(hidden_states, pooling_metadata)

-        # Apply optimized pooling
-        try:
+        # Apply pooling based on configured backend
+        if self.pooling_backend == "triton":
             pooled_data = self._apply_vision_pooling_optimized(
                 hidden_states, token_ids_list, prompt_lens
             )
-        except RuntimeError as e:
-            if "out of memory" in str(e).lower():
-                logger.warning("OOM during optimized pooling, falling back to batched PyTorch")
-                # Fallback to a more memory-efficient PyTorch implementation
-                pooled_data = self._apply_vision_pooling_pytorch(
-                    hidden_states, token_ids_list, prompt_lens
-                )
-            else:
-                raise
+        else:  # self.pooling_backend == "pytorch"
+            pooled_data = self._apply_vision_pooling_pytorch(
+                hidden_states, token_ids_list, prompt_lens
+            )

         # Build output
         pooled_outputs = [
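For context, the change replaces the old "try Triton, catch OOM, fall back to PyTorch" path with an explicit, config-driven backend choice. The sketch below reproduces that validate-and-dispatch pattern outside of vLLM; it is a minimal illustration only, and the stand-in functions _pool_triton and _pool_pytorch are hypothetical placeholders, not part of this diff or of vLLM's API.

import logging

logger = logging.getLogger(__name__)

# Hypothetical stand-ins for the two pooling paths touched by the diff.
def _pool_triton(hidden_states, token_ids_list, prompt_lens):
    return "pooled-by-triton"

def _pool_pytorch(hidden_states, token_ids_list, prompt_lens):
    return "pooled-by-pytorch"

def resolve_backend(model_config) -> str:
    # Mirrors the getattr-based read in __init__: default to "triton" when
    # the attribute is absent, warn and reset when the value is invalid.
    backend = getattr(model_config, "jina_pooling_backend", "triton")
    if backend not in ("triton", "pytorch"):
        logger.warning(
            "Invalid jina_pooling_backend %r. Must be 'triton' or "
            "'pytorch'. Defaulting to 'triton'.", backend)
        backend = "triton"
    return backend

def pool(backend, hidden_states, token_ids_list, prompt_lens):
    # Explicit dispatch replaces the old RuntimeError/OOM fallback.
    if backend == "triton":
        return _pool_triton(hidden_states, token_ids_list, prompt_lens)
    return _pool_pytorch(hidden_states, token_ids_list, prompt_lens)

With no jina_pooling_backend attribute on the config, resolve_backend returns "triton"; an unrecognized value is logged and reset rather than raised, so a typo degrades to the default backend instead of failing at startup.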
