@@ -278,78 +278,65 @@ def pooler(
         """Thread-safe pooler with production error handling."""
         start_time = time.time() if self.observability_config else None

+        # Validate inputs
+        if hidden_states is None or hidden_states.numel() == 0:
+            logger.warning("Empty hidden states received")
+            return PoolerOutput(outputs=[])
+
+        # Extract token IDs safely from metadata
+        token_ids_list, seq_ids = self._extract_token_ids_safe(pooling_metadata)
+
+        if not token_ids_list:
+            logger.warning("No valid sequences found for pooling")
+            # Fallback to base pooler
+            return self._base_pooler(hidden_states, pooling_metadata)
+
+        # Get prompt lengths
+        prompt_lens = PoolingTensors.from_pooling_metadata(
+            pooling_metadata, hidden_states.device
+        ).prompt_lens
+
+        # Validate lengths match
+        if len(token_ids_list) != len(prompt_lens):
+            logger.error(f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths")
+            return self._base_pooler(hidden_states, pooling_metadata)
+
+        # Apply optimized pooling
         try:
-            # Validate inputs
-            if hidden_states is None or hidden_states.numel() == 0:
-                logger.warning("Empty hidden states received")
-                return PoolerOutput(outputs=[])
-
-            # Extract token IDs safely from metadata
-            token_ids_list, seq_ids = self._extract_token_ids_safe(pooling_metadata)
-
-            if not token_ids_list:
-                logger.warning("No valid sequences found for pooling")
-                # Fallback to base pooler
-                return self._base_pooler(hidden_states, pooling_metadata)
-
-            # Get prompt lengths
-            prompt_lens = PoolingTensors.from_pooling_metadata(
-                pooling_metadata, hidden_states.device
-            ).prompt_lens
-
-            # Validate lengths match
-            if len(token_ids_list) != len(prompt_lens):
-                logger.error(f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths")
-                return self._base_pooler(hidden_states, pooling_metadata)
-
-            # Apply optimized pooling
-            try:
-                pooled_data = self._apply_vision_pooling_optimized(
-                    hidden_states, token_ids_list, prompt_lens
-                )
-            except RuntimeError as e:
-                if "out of memory" in str(e).lower():
-                    logger.warning("OOM during pooling, falling back to sequential processing")
-                    # Process sequences one by one to reduce memory
-                    pooled_data = []
-                    for i in range(len(token_ids_list)):
-                        single_pooled = self._apply_vision_pooling_pytorch(
-                            hidden_states,
-                            [token_ids_list[i]],
-                            prompt_lens[i:i + 1]
-                        )
-                        pooled_data.extend(single_pooled)
-                else:
-                    raise
-
-            # Build output
-            pooled_outputs = [
-                PoolingSequenceGroupOutput(data) for data in pooled_data
-            ]
-
-            # Record metrics
-            if self.observability_config:
-                elapsed_ms = (time.time() - start_time) * 1000
-                self._pooling_time_ms += elapsed_ms
-                self._pooling_count += 1
-
-                if self._pooling_count % 100 == 0:
-                    avg_time = self._pooling_time_ms / self._pooling_count
-                    logger.debug(f"Average pooling time: {avg_time:.2f} ms")
-
-            return PoolerOutput(outputs=pooled_outputs)
+            pooled_data = self._apply_vision_pooling_optimized(
+                hidden_states, token_ids_list, prompt_lens
+            )
+        except RuntimeError as e:
+            if "out of memory" in str(e).lower():
+                logger.warning("OOM during pooling, falling back to sequential processing")
+                # Process sequences one by one to reduce memory
+                pooled_data = []
+                for i in range(len(token_ids_list)):
+                    single_pooled = self._apply_vision_pooling_pytorch(
+                        hidden_states,
+                        [token_ids_list[i]],
+                        prompt_lens[i:i + 1]
+                    )
+                    pooled_data.extend(single_pooled)
+            else:
+                raise
+
+        # Build output
+        pooled_outputs = [
+            PoolingSequenceGroupOutput(data) for data in pooled_data
+        ]
+
+        # Record metrics
+        if self.observability_config:
+            elapsed_ms = (time.time() - start_time) * 1000
+            self._pooling_time_ms += elapsed_ms
+            self._pooling_count += 1

-        except Exception as e:
-            logger.error(f"Error in pooler: {type(e).__name__}: {e}")
-            # Graceful degradation to base pooler
-            logger.info("Falling back to base pooler due to error")
-            return self._base_pooler(hidden_states, pooling_metadata)
+            if self._pooling_count % 100 == 0:
+                avg_time = self._pooling_time_ms / self._pooling_count
+                logger.debug(f"Average pooling time: {avg_time:.2f} ms")

-        finally:
-            # Rely on Python's garbage collector for releasing tensors.
-            # torch.cuda.empty_cache() is a blocking and expensive operation
-            # that should be used sparingly.
-            pass
+        return PoolerOutput(outputs=pooled_outputs)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         """Load weights with validation and error handling."""