Skip to content

Commit 5b67b13

Browse files
committed
refactor: fail fast
1 parent d7d6b60 commit 5b67b13

File tree

1 file changed

+56
-69
lines changed

1 file changed

+56
-69
lines changed

vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 56 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -278,78 +278,65 @@ def pooler(
278278
"""Thread-safe pooler with production error handling."""
279279
start_time = time.time() if self.observability_config else None
280280

281+
# Validate inputs
282+
if hidden_states is None or hidden_states.numel() == 0:
283+
logger.warning("Empty hidden states received")
284+
return PoolerOutput(outputs=[])
285+
286+
# Extract token IDs safely from metadata
287+
token_ids_list, seq_ids = self._extract_token_ids_safe(pooling_metadata)
288+
289+
if not token_ids_list:
290+
logger.warning("No valid sequences found for pooling")
291+
# Fallback to base pooler
292+
return self._base_pooler(hidden_states, pooling_metadata)
293+
294+
# Get prompt lengths
295+
prompt_lens = PoolingTensors.from_pooling_metadata(
296+
pooling_metadata, hidden_states.device
297+
).prompt_lens
298+
299+
# Validate lengths match
300+
if len(token_ids_list) != len(prompt_lens):
301+
logger.error(f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths")
302+
return self._base_pooler(hidden_states, pooling_metadata)
303+
304+
# Apply optimized pooling
281305
try:
282-
# Validate inputs
283-
if hidden_states is None or hidden_states.numel() == 0:
284-
logger.warning("Empty hidden states received")
285-
return PoolerOutput(outputs=[])
286-
287-
# Extract token IDs safely from metadata
288-
token_ids_list, seq_ids = self._extract_token_ids_safe(pooling_metadata)
289-
290-
if not token_ids_list:
291-
logger.warning("No valid sequences found for pooling")
292-
# Fallback to base pooler
293-
return self._base_pooler(hidden_states, pooling_metadata)
294-
295-
# Get prompt lengths
296-
prompt_lens = PoolingTensors.from_pooling_metadata(
297-
pooling_metadata, hidden_states.device
298-
).prompt_lens
299-
300-
# Validate lengths match
301-
if len(token_ids_list) != len(prompt_lens):
302-
logger.error(f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths")
303-
return self._base_pooler(hidden_states, pooling_metadata)
304-
305-
# Apply optimized pooling
306-
try:
307-
pooled_data = self._apply_vision_pooling_optimized(
308-
hidden_states, token_ids_list, prompt_lens
309-
)
310-
except RuntimeError as e:
311-
if "out of memory" in str(e).lower():
312-
logger.warning("OOM during pooling, falling back to sequential processing")
313-
# Process sequences one by one to reduce memory
314-
pooled_data = []
315-
for i in range(len(token_ids_list)):
316-
single_pooled = self._apply_vision_pooling_pytorch(
317-
hidden_states,
318-
[token_ids_list[i]],
319-
prompt_lens[i:i+1]
320-
)
321-
pooled_data.extend(single_pooled)
322-
else:
323-
raise
324-
325-
# Build output
326-
pooled_outputs = [
327-
PoolingSequenceGroupOutput(data) for data in pooled_data
328-
]
329-
330-
# Record metrics
331-
if self.observability_config:
332-
elapsed_ms = (time.time() - start_time) * 1000
333-
self._pooling_time_ms += elapsed_ms
334-
self._pooling_count += 1
335-
336-
if self._pooling_count % 100 == 0:
337-
avg_time = self._pooling_time_ms / self._pooling_count
338-
logger.debug(f"Average pooling time: {avg_time:.2f}ms")
339-
340-
return PoolerOutput(outputs=pooled_outputs)
306+
pooled_data = self._apply_vision_pooling_optimized(
307+
hidden_states, token_ids_list, prompt_lens
308+
)
309+
except RuntimeError as e:
310+
if "out of memory" in str(e).lower():
311+
logger.warning("OOM during pooling, falling back to sequential processing")
312+
# Process sequences one by one to reduce memory
313+
pooled_data = []
314+
for i in range(len(token_ids_list)):
315+
single_pooled = self._apply_vision_pooling_pytorch(
316+
hidden_states,
317+
[token_ids_list[i]],
318+
prompt_lens[i:i+1]
319+
)
320+
pooled_data.extend(single_pooled)
321+
else:
322+
raise
323+
324+
# Build output
325+
pooled_outputs = [
326+
PoolingSequenceGroupOutput(data) for data in pooled_data
327+
]
328+
329+
# Record metrics
330+
if self.observability_config:
331+
elapsed_ms = (time.time() - start_time) * 1000
332+
self._pooling_time_ms += elapsed_ms
333+
self._pooling_count += 1
341334

342-
except Exception as e:
343-
logger.error(f"Error in pooler: {type(e).__name__}: {e}")
344-
# Graceful degradation to base pooler
345-
logger.info("Falling back to base pooler due to error")
346-
return self._base_pooler(hidden_states, pooling_metadata)
335+
if self._pooling_count % 100 == 0:
336+
avg_time = self._pooling_time_ms / self._pooling_count
337+
logger.debug(f"Average pooling time: {avg_time:.2f}ms")
347338

348-
finally:
349-
# Rely on Python's garbage collector for releasing tensors.
350-
# torch.cuda.empty_cache() is a blocking and expensive operation
351-
# that should be used sparingly.
352-
pass
339+
return PoolerOutput(outputs=pooled_outputs)
353340

354341
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
355342
"""Load weights with validation and error handling."""

0 commit comments

Comments (0)