
Commit caea1fe

refactor: review
1 parent d0d7b26 commit caea1fe

File tree

2 files changed: +46 -47 lines changed

vllm/model_executor/layers/pooler.py

Lines changed: 44 additions & 0 deletions
@@ -19,6 +19,10 @@
     get_cross_encoder_activation_function)
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

+from vllm.triton_utils import tl, triton
+HAS_TRITON = triton is not None
+
+
 PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]

@@ -471,3 +475,43 @@ def forward(

         pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
         return PoolerOutput(outputs=pooled_outputs)
+
+if HAS_TRITON:
+    @triton.jit
+    def extract_vision_tokens_kernel(
+        hidden_states_ptr,
+        token_ids_ptr,
+        output_ptr,
+        seq_start,
+        seq_len,
+        hidden_size,
+        vision_start_id: tl.constexpr,
+        vision_end_id: tl.constexpr,
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        """Triton kernel to extract and pool vision tokens efficiently."""
+        pid = tl.program_id(0)
+
+        if pid >= hidden_size:
+            return
+
+        # Find vision token range
+        vision_count = 0
+        accumulator = 0.0
+
+        for i in range(seq_len):
+            token_id = tl.load(token_ids_ptr + seq_start + i)
+            if token_id >= vision_start_id and token_id <= vision_end_id:
+                hidden_val = tl.load(
+                    hidden_states_ptr + (seq_start + i) * hidden_size + pid
+                )
+                accumulator += hidden_val
+                vision_count += 1
+
+        # Store mean pooled result
+        if vision_count > 0:
+            result = accumulator / vision_count
+        else:
+            result = 0.0
+
+        tl.store(output_ptr + pid, result)
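
For context, a minimal sketch of how a caller might launch the relocated kernel for a single sequence; the wrapper name pool_vision_tokens, the tensor shapes, and the BLOCK_SIZE value are illustrative assumptions, not part of this commit, which only moves the kernel definition into pooler.py.

import torch

from vllm.model_executor.layers.pooler import (HAS_TRITON,
                                               extract_vision_tokens_kernel)


def pool_vision_tokens(hidden_states: torch.Tensor,
                       token_ids: torch.Tensor,
                       seq_start: int,
                       seq_len: int,
                       vision_start_id: int,
                       vision_end_id: int) -> torch.Tensor:
    """Hypothetical wrapper: mean-pool the vision tokens of one sequence."""
    assert HAS_TRITON, "this sketch assumes Triton is available"
    hidden_size = hidden_states.shape[-1]
    output = torch.empty(hidden_size,
                         dtype=hidden_states.dtype,
                         device=hidden_states.device)
    # One program per hidden dimension, matching pid = tl.program_id(0)
    # and the pid >= hidden_size bound check in the kernel above.
    grid = (hidden_size,)
    extract_vision_tokens_kernel[grid](
        hidden_states,             # (num_tokens, hidden_size), contiguous
        token_ids,                 # (num_tokens,) token ids for the batch
        output,                    # (hidden_size,) pooled vision embedding
        seq_start,
        seq_len,
        hidden_size,
        vision_start_id=vision_start_id,
        vision_end_id=vision_end_id,
        BLOCK_SIZE=1024,           # required constexpr; unused in the loop body
    )
    return output

Each of the hidden_size programs scans the full token-id slice, so the kernel is simple but memory-bound; every element of output is written (zero when a sequence has no vision tokens), so torch.empty is safe here.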

vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 2 additions & 47 deletions
@@ -9,14 +9,7 @@
 import torch.nn.functional as F
 from torch import nn

-try:
-    import triton
-    import triton.language as tl
-    HAS_TRITON = True
-except ImportError:
-    HAS_TRITON = False
-    triton = None
-    tl = None
+from vllm.model_executor.layers.pooler import HAS_TRITON, extract_vision_tokens_kernel

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -44,45 +37,7 @@


 # Triton kernel for optimized vision token extraction
-if HAS_TRITON:
-    @triton.jit
-    def extract_vision_tokens_kernel(
-        hidden_states_ptr,
-        token_ids_ptr,
-        output_ptr,
-        seq_start,
-        seq_len,
-        hidden_size,
-        vision_start_id: tl.constexpr,
-        vision_end_id: tl.constexpr,
-        BLOCK_SIZE: tl.constexpr,
-    ):
-        """Triton kernel to extract and pool vision tokens efficiently."""
-        pid = tl.program_id(0)
-
-        if pid >= hidden_size:
-            return
-
-        # Find vision token range
-        vision_count = 0
-        accumulator = 0.0
-
-        for i in range(seq_len):
-            token_id = tl.load(token_ids_ptr + seq_start + i)
-            if token_id >= vision_start_id and token_id <= vision_end_id:
-                hidden_val = tl.load(
-                    hidden_states_ptr + (seq_start + i) * hidden_size + pid
-                )
-                accumulator += hidden_val
-                vision_count += 1
-
-        # Store mean pooled result
-        if vision_count > 0:
-            result = accumulator / vision_count
-        else:
-            result = 0.0
-
-        tl.store(output_ptr + pid, result)
+


 @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
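
For comparison, an eager PyTorch sketch with the same pooling semantics as the kernel, useful as a mental model (or as a possible fallback when HAS_TRITON is False); the function name and the fallback behaviour are assumptions, not part of this commit.

import torch


def extract_vision_tokens_eager(hidden_states: torch.Tensor,
                                token_ids: torch.Tensor,
                                seq_start: int,
                                seq_len: int,
                                vision_start_id: int,
                                vision_end_id: int) -> torch.Tensor:
    """Mean-pool the hidden states of tokens whose id is in the vision range."""
    seq = slice(seq_start, seq_start + seq_len)
    ids = token_ids[seq]
    mask = (ids >= vision_start_id) & (ids <= vision_end_id)
    if mask.any():
        return hidden_states[seq][mask].mean(dim=0)
    # The Triton kernel stores zeros when a sequence has no vision tokens.
    return hidden_states.new_zeros(hidden_states.shape[-1])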
