File tree Expand file tree Collapse file tree 2 files changed +45
-47
lines changed Expand file tree Collapse file tree 2 files changed +45
-47
lines changed Original file line number Diff line number Diff line change @@ -469,5 +469,48 @@ def forward(
469
469
pooled_output )
470
470
])
471
471
472
+ from vllm .triton_utils import tl , triton
473
+ HAS_TRITON = triton is not None
474
+
472
475
pooled_outputs = [PoolingSequenceGroupOutput (data ) for data in scores ]
473
476
return PoolerOutput (outputs = pooled_outputs )
477
+
478
+ if HAS_TRITON :
479
+ @triton .jit
480
+ def extract_vision_tokens_kernel (
481
+ hidden_states_ptr ,
482
+ token_ids_ptr ,
483
+ output_ptr ,
484
+ seq_start ,
485
+ seq_len ,
486
+ hidden_size ,
487
+ vision_start_id : tl .constexpr ,
488
+ vision_end_id : tl .constexpr ,
489
+ BLOCK_SIZE : tl .constexpr ,
490
+ ):
491
+ """Triton kernel to extract and pool vision tokens efficiently."""
492
+ pid = tl .program_id (0 )
493
+
494
+ if pid >= hidden_size :
495
+ return
496
+
497
+ # Find vision token range
498
+ vision_count = 0
499
+ accumulator = 0.0
500
+
501
+ for i in range (seq_len ):
502
+ token_id = tl .load (token_ids_ptr + seq_start + i )
503
+ if token_id >= vision_start_id and token_id <= vision_end_id :
504
+ hidden_val = tl .load (
505
+ hidden_states_ptr + (seq_start + i ) * hidden_size + pid
506
+ )
507
+ accumulator += hidden_val
508
+ vision_count += 1
509
+
510
+ # Store mean pooled result
511
+ if vision_count > 0 :
512
+ result = accumulator / vision_count
513
+ else :
514
+ result = 0.0
515
+
516
+ tl .store (output_ptr + pid , result )
Original file line number Diff line number Diff line change 9
9
import torch .nn .functional as F
10
10
from torch import nn
11
11
12
- try :
13
- import triton
14
- import triton .language as tl
15
- HAS_TRITON = True
16
- except ImportError :
17
- HAS_TRITON = False
18
- triton = None
19
- tl = None
12
+ from vllm .model_executor .layers .pooler import HAS_TRITON , extract_vision_tokens_kernel
20
13
21
14
from vllm .config import VllmConfig
22
15
from vllm .logger import init_logger
44
37
45
38
46
39
# Triton kernel for optimized vision token extraction
47
- if HAS_TRITON :
48
- @triton .jit
49
- def extract_vision_tokens_kernel (
50
- hidden_states_ptr ,
51
- token_ids_ptr ,
52
- output_ptr ,
53
- seq_start ,
54
- seq_len ,
55
- hidden_size ,
56
- vision_start_id : tl .constexpr ,
57
- vision_end_id : tl .constexpr ,
58
- BLOCK_SIZE : tl .constexpr ,
59
- ):
60
- """Triton kernel to extract and pool vision tokens efficiently."""
61
- pid = tl .program_id (0 )
62
-
63
- if pid >= hidden_size :
64
- return
65
-
66
- # Find vision token range
67
- vision_count = 0
68
- accumulator = 0.0
69
-
70
- for i in range (seq_len ):
71
- token_id = tl .load (token_ids_ptr + seq_start + i )
72
- if token_id >= vision_start_id and token_id <= vision_end_id :
73
- hidden_val = tl .load (
74
- hidden_states_ptr + (seq_start + i ) * hidden_size + pid
75
- )
76
- accumulator += hidden_val
77
- vision_count += 1
78
-
79
- # Store mean pooled result
80
- if vision_count > 0 :
81
- result = accumulator / vision_count
82
- else :
83
- result = 0.0
84
-
85
- tl .store (output_ptr + pid , result )
40
+
86
41
87
42
88
43
@MULTIMODAL_REGISTRY .register_processor (Qwen2VLMultiModalProcessor ,
You can’t perform that action at this time.
0 commit comments