 VISION_START_TOKEN_ID = 151652
 VISION_END_TOKEN_ID = 151653
 
-# Maximum sequence length for safety
-MAX_SEQUENCE_LENGTH = 512 * 1024  # 512K tokens
-
 
 PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
 
@@ -227,16 +224,10 @@ def _apply_vision_pooling_optimized(
                 # Regular mean pooling for text
                 seq_states = hidden_states[offset:offset + prompt_len]
                 output = seq_states.mean(dim=0)
-
-                # Normalize (check for zero vector to avoid NaN)
-                if output.count_nonzero() > 0:
-                    output = F.normalize(output, p=2, dim=-1)
-                else:
-                    # If all zeros, fall back to PyTorch implementation
-                    logger.warning("Triton kernel returned zero vector, falling back to PyTorch")
-                    seq_states = hidden_states[offset:offset + prompt_len]
-                    output = seq_states.mean(dim=0)
-                    output = F.normalize(output, p=2, dim=-1)
+
+                # Normalize and handle potential NaNs by replacing with zeros
+                output = F.normalize(output, p=2, dim=-1)
+                output = torch.nan_to_num(output)
 
             pooled_outputs.append(output)
 
             offset += prompt_len
@@ -256,11 +247,6 @@ def _apply_vision_pooling_pytorch(
         for token_ids, prompt_len in zip(token_ids_list, prompt_lens):
             prompt_len = int(prompt_len.item())
 
-            # Safety check for sequence length
-            if prompt_len > MAX_SEQUENCE_LENGTH:
-                logger.warning(f"Sequence length {prompt_len} exceeds maximum {MAX_SEQUENCE_LENGTH}")
-                prompt_len = MAX_SEQUENCE_LENGTH
-
             # Extract sequence states and tokens
             seq_states = hidden_states[offset:offset + prompt_len]
 
@@ -325,10 +311,9 @@ def pooler(
         ).prompt_lens
 
         # Validate lengths match
-        if len(token_ids_list) != len(prompt_lens):
-            raise AssertionError(
-                f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths"
-            )
+        assert len(token_ids_list) == len(prompt_lens), (
+            f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths"
+        )
 
         # Apply pooling based on configured backend
         if self.pooling_backend == "triton":
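
A minimal sketch (not part of the diff) of the simplified normalization path above, assuming the pooled hidden states are a plain float tensor; the helper name normalize_pooled is illustrative only:

    import torch
    import torch.nn.functional as F

    def normalize_pooled(seq_states: torch.Tensor) -> torch.Tensor:
        # Mean-pool the sequence, then L2-normalize.
        output = seq_states.mean(dim=0)
        # F.normalize clamps the norm with eps, so an all-zero vector stays zero.
        output = F.normalize(output, p=2, dim=-1)
        # nan_to_num turns any NaNs propagated from upstream into zeros,
        # standing in for the removed zero-vector check and PyTorch fallback.
        return torch.nan_to_num(output)

    print(normalize_pooled(torch.zeros(4, 8)))   # all-zero embedding, no NaNs
    print(normalize_pooled(torch.randn(4, 8)))   # unit-norm embedding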