Skip to content

Commit 54839c7

Browse files
committed
refactor: review
1 parent d0d7b26 commit 54839c7

File tree

2 files changed

+45
-47
lines changed

2 files changed

+45
-47
lines changed

vllm/model_executor/layers/pooler.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,3 +471,46 @@ def forward(
471471

472472
pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores]
473473
return PoolerOutput(outputs=pooled_outputs)
474+
475+
from vllm.triton_utils import tl, triton
476+
HAS_TRITON = triton is not None
477+
478+
if HAS_TRITON:
479+
@triton.jit
480+
def extract_vision_tokens_kernel(
481+
hidden_states_ptr,
482+
token_ids_ptr,
483+
output_ptr,
484+
seq_start,
485+
seq_len,
486+
hidden_size,
487+
vision_start_id: tl.constexpr,
488+
vision_end_id: tl.constexpr,
489+
BLOCK_SIZE: tl.constexpr,
490+
):
491+
"""Triton kernel to extract and pool vision tokens efficiently."""
492+
pid = tl.program_id(0)
493+
494+
if pid >= hidden_size:
495+
return
496+
497+
# Find vision token range
498+
vision_count = 0
499+
accumulator = 0.0
500+
501+
for i in range(seq_len):
502+
token_id = tl.load(token_ids_ptr + seq_start + i)
503+
if token_id >= vision_start_id and token_id <= vision_end_id:
504+
hidden_val = tl.load(
505+
hidden_states_ptr + (seq_start + i) * hidden_size + pid
506+
)
507+
accumulator += hidden_val
508+
vision_count += 1
509+
510+
# Store mean pooled result
511+
if vision_count > 0:
512+
result = accumulator / vision_count
513+
else:
514+
result = 0.0
515+
516+
tl.store(output_ptr + pid, result)

vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 2 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,7 @@
99
import torch.nn.functional as F
1010
from torch import nn
1111

12-
try:
13-
import triton
14-
import triton.language as tl
15-
HAS_TRITON = True
16-
except ImportError:
17-
HAS_TRITON = False
18-
triton = None
19-
tl = None
12+
from vllm.model_executor.layers.pooler import HAS_TRITON, extract_vision_tokens_kernel
2013

2114
from vllm.config import VllmConfig
2215
from vllm.logger import init_logger
@@ -44,45 +37,7 @@
4437

4538

4639
# Triton kernel for optimized vision token extraction
47-
if HAS_TRITON:
48-
@triton.jit
49-
def extract_vision_tokens_kernel(
50-
hidden_states_ptr,
51-
token_ids_ptr,
52-
output_ptr,
53-
seq_start,
54-
seq_len,
55-
hidden_size,
56-
vision_start_id: tl.constexpr,
57-
vision_end_id: tl.constexpr,
58-
BLOCK_SIZE: tl.constexpr,
59-
):
60-
"""Triton kernel to extract and pool vision tokens efficiently."""
61-
pid = tl.program_id(0)
62-
63-
if pid >= hidden_size:
64-
return
65-
66-
# Find vision token range
67-
vision_count = 0
68-
accumulator = 0.0
69-
70-
for i in range(seq_len):
71-
token_id = tl.load(token_ids_ptr + seq_start + i)
72-
if token_id >= vision_start_id and token_id <= vision_end_id:
73-
hidden_val = tl.load(
74-
hidden_states_ptr + (seq_start + i) * hidden_size + pid
75-
)
76-
accumulator += hidden_val
77-
vision_count += 1
78-
79-
# Store mean pooled result
80-
if vision_count > 0:
81-
result = accumulator / vision_count
82-
else:
83-
result = 0.0
84-
85-
tl.store(output_ptr + pid, result)
40+
8641

8742

8843
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,

0 commit comments

Comments
 (0)