
Commit 4a83cd9

refactor: normalize
1 parent 714e283 commit 4a83cd9

File tree

1 file changed (+7, -22 lines changed)


vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 7 additions & 22 deletions
@@ -39,9 +39,6 @@
 VISION_START_TOKEN_ID = 151652
 VISION_END_TOKEN_ID = 151653
 
-# Maximum sequence length for safety
-MAX_SEQUENCE_LENGTH = 512 * 1024  # 512K tokens
-
 
 PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
 

@@ -227,16 +224,10 @@ def _apply_vision_pooling_optimized(
             # Regular mean pooling for text
             seq_states = hidden_states[offset:offset + prompt_len]
             output = seq_states.mean(dim=0)
-
-            # Normalize (check for zero vector to avoid NaN)
-            if output.count_nonzero() > 0:
-                output = F.normalize(output, p=2, dim=-1)
-            else:
-                # If all zeros, fall back to PyTorch implementation
-                logger.warning("Triton kernel returned zero vector, falling back to PyTorch")
-                seq_states = hidden_states[offset:offset + prompt_len]
-                output = seq_states.mean(dim=0)
-                output = F.normalize(output, p=2, dim=-1)
+
+            # Normalize and handle potential NaNs by replacing with zeros
+            output = F.normalize(output, p=2, dim=-1)
+            output = torch.nan_to_num(output)
             pooled_outputs.append(output)
 
             offset += prompt_len
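
For context, here is a minimal standalone sketch of the new text-pooling path. The helper name, tensor shapes, and example values are illustrative assumptions, not part of the vLLM source: F.normalize clamps the norm to a small eps, so an all-zero mean stays zero rather than dividing by zero, and torch.nan_to_num zeroes out any NaNs that propagate from upstream, which is what replaces the earlier zero-vector check and PyTorch fallback.

    # Illustrative sketch only -- names and shapes are assumptions,
    # not the vLLM implementation.
    import torch
    import torch.nn.functional as F

    def mean_pool_and_normalize(hidden_states: torch.Tensor,
                                offset: int,
                                prompt_len: int) -> torch.Tensor:
        # Mean-pool the token states belonging to one prompt.
        seq_states = hidden_states[offset:offset + prompt_len]
        output = seq_states.mean(dim=0)
        # L2-normalize; F.normalize clamps the norm to eps, so an
        # all-zero vector stays zero instead of dividing by zero.
        output = F.normalize(output, p=2, dim=-1)
        # Replace any NaNs (e.g. from a faulty upstream kernel) with zeros.
        return torch.nan_to_num(output)

    # Usage: three prompts of lengths 4, 2, 5 packed into one tensor.
    hidden = torch.randn(11, 8)
    offsets, lens = [0, 4, 6], [4, 2, 5]
    pooled = [mean_pool_and_normalize(hidden, o, n)
              for o, n in zip(offsets, lens)]
    print(torch.stack(pooled).shape)  # torch.Size([3, 8])
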
@@ -256,11 +247,6 @@ def _apply_vision_pooling_pytorch(
         for token_ids, prompt_len in zip(token_ids_list, prompt_lens):
             prompt_len = int(prompt_len.item())
 
-            # Safety check for sequence length
-            if prompt_len > MAX_SEQUENCE_LENGTH:
-                logger.warning(f"Sequence length {prompt_len} exceeds maximum {MAX_SEQUENCE_LENGTH}")
-                prompt_len = MAX_SEQUENCE_LENGTH
-
             # Extract sequence states and tokens
             seq_states = hidden_states[offset:offset + prompt_len]
 
@@ -325,10 +311,9 @@ def pooler(
         ).prompt_lens
 
         # Validate lengths match
-        if len(token_ids_list) != len(prompt_lens):
-            raise AssertionError(
-                f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths"
-            )
+        assert len(token_ids_list) == len(prompt_lens), (
+            f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths"
+        )
 
         # Apply pooling based on configured backend
         if self.pooling_backend == "triton":
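
A note on the last hunk: the bare assert is equivalent to the removed if/raise AssertionError branch while assertions are enabled, but it is stripped entirely under python -O, so the length check becomes a debug-time guard rather than an unconditional runtime error. A minimal standalone illustration (the example values are made up, not vLLM data):

    token_ids_list, prompt_lens = [[1, 2], [3, 4, 5]], [2, 3]
    assert len(token_ids_list) == len(prompt_lens), (
        f"Mismatch: {len(token_ids_list)} sequences vs {len(prompt_lens)} lengths"
    )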
