File tree Expand file tree Collapse file tree 2 files changed +46
-47
lines changed Expand file tree Collapse file tree 2 files changed +46
-47
lines changed Original file line number Diff line number Diff line change 19
19
get_cross_encoder_activation_function )
20
20
from vllm .v1 .pool .metadata import PoolingMetadata as V1PoolingMetadata
21
21
22
+ from vllm .triton_utils import tl , triton
23
+ HAS_TRITON = triton is not None
24
+
25
+
22
26
PoolingMetadata = Union [V0PoolingMetadata , V1PoolingMetadata ]
23
27
24
28
@@ -471,3 +475,43 @@ def forward(
471
475
472
476
pooled_outputs = [PoolingSequenceGroupOutput (data ) for data in scores ]
473
477
return PoolerOutput (outputs = pooled_outputs )
478
+
479
+ if HAS_TRITON :
480
+ @triton .jit
481
+ def extract_vision_tokens_kernel (
482
+ hidden_states_ptr ,
483
+ token_ids_ptr ,
484
+ output_ptr ,
485
+ seq_start ,
486
+ seq_len ,
487
+ hidden_size ,
488
+ vision_start_id : tl .constexpr ,
489
+ vision_end_id : tl .constexpr ,
490
+ BLOCK_SIZE : tl .constexpr ,
491
+ ):
492
+ """Triton kernel to extract and pool vision tokens efficiently."""
493
+ pid = tl .program_id (0 )
494
+
495
+ if pid >= hidden_size :
496
+ return
497
+
498
+ # Find vision token range
499
+ vision_count = 0
500
+ accumulator = 0.0
501
+
502
+ for i in range (seq_len ):
503
+ token_id = tl .load (token_ids_ptr + seq_start + i )
504
+ if token_id >= vision_start_id and token_id <= vision_end_id :
505
+ hidden_val = tl .load (
506
+ hidden_states_ptr + (seq_start + i ) * hidden_size + pid
507
+ )
508
+ accumulator += hidden_val
509
+ vision_count += 1
510
+
511
+ # Store mean pooled result
512
+ if vision_count > 0 :
513
+ result = accumulator / vision_count
514
+ else :
515
+ result = 0.0
516
+
517
+ tl .store (output_ptr + pid , result )
Original file line number Diff line number Diff line change 9
9
import torch .nn .functional as F
10
10
from torch import nn
11
11
12
- try :
13
- import triton
14
- import triton .language as tl
15
- HAS_TRITON = True
16
- except ImportError :
17
- HAS_TRITON = False
18
- triton = None
19
- tl = None
12
+ from vllm .model_executor .layers .pooler import HAS_TRITON , extract_vision_tokens_kernel
20
13
21
14
from vllm .config import VllmConfig
22
15
from vllm .logger import init_logger
44
37
45
38
46
39
# Triton kernel for optimized vision token extraction
47
- if HAS_TRITON :
48
- @triton .jit
49
- def extract_vision_tokens_kernel (
50
- hidden_states_ptr ,
51
- token_ids_ptr ,
52
- output_ptr ,
53
- seq_start ,
54
- seq_len ,
55
- hidden_size ,
56
- vision_start_id : tl .constexpr ,
57
- vision_end_id : tl .constexpr ,
58
- BLOCK_SIZE : tl .constexpr ,
59
- ):
60
- """Triton kernel to extract and pool vision tokens efficiently."""
61
- pid = tl .program_id (0 )
62
-
63
- if pid >= hidden_size :
64
- return
65
-
66
- # Find vision token range
67
- vision_count = 0
68
- accumulator = 0.0
69
-
70
- for i in range (seq_len ):
71
- token_id = tl .load (token_ids_ptr + seq_start + i )
72
- if token_id >= vision_start_id and token_id <= vision_end_id :
73
- hidden_val = tl .load (
74
- hidden_states_ptr + (seq_start + i ) * hidden_size + pid
75
- )
76
- accumulator += hidden_val
77
- vision_count += 1
78
-
79
- # Store mean pooled result
80
- if vision_count > 0 :
81
- result = accumulator / vision_count
82
- else :
83
- result = 0.0
84
-
85
- tl .store (output_ptr + pid , result )
40
+
86
41
87
42
88
43
@MULTIMODAL_REGISTRY .register_processor (Qwen2VLMultiModalProcessor ,
You can’t perform that action at this time.
0 commit comments