from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
-    Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
-    Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
-    Qwen2_5_VLProcessingInfo)
+    Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer,
+    Qwen2_5_VLDummyInputsBuilder, Qwen2_5_VLForConditionalGeneration,
+    Qwen2_5_VLMultiModalProcessor, Qwen2_5_VLProcessingInfo)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -152,6 +152,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


+class AscendQwen2_5_VisionRotaryEmbedding(Qwen2_5_VisionRotaryEmbedding):
+
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__(dim, theta)
+        inv_freq = 1.0 / (theta
+                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.inv_freq = inv_freq
+
+
class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):

    def __init__(
@@ -166,6 +175,9 @@ def __init__(
        norm_layer = partial(RMSNorm, eps=norm_eps)
        self.interleaved = interleaved
        self.enable_pad = False
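+        # Rotary table sized for half of each attention head's dimensions;
+        # the height and width axes of the patch grid each drive half of
+        # that (see rot_pos_emb below).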
+        head_dim = self.hidden_size // self.num_heads
+        self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim //
+                                                                   2)
        self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
            patch_size=vision_config.patch_size,
            temporal_patch_size=vision_config.temporal_patch_size,
@@ -298,6 +310,66 @@ def load_weights(self, weights: Iterable[Tuple[str,
            loaded_params.add(name)
        return loaded_params

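+    # Compute per-patch 2-D rotary position embeddings from the (t, h, w)
+    # grids; positions are reordered into spatial-merge blocks so they match
+    # the patch ordering expected by the merger.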
+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            pos_ids.append(
+                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
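+    # Assign every merged patch to a local attention window; returns the
+    # permutation that groups window members together plus the cumulative
+    # window sequence lengths (in units of spatial_merge_unit patches).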
+    def get_window_index(self, grid_thw):
+        window_index: list = []
+        cu_window_seqlens: list = [0]
+        window_index_id = 0
+        vit_merger_window_size = (self.window_size //
+                                  self.spatial_merge_size // self.patch_size)
+
+        for grid_t, grid_h, grid_w in grid_thw:
+            llm_grid_h = grid_h // self.spatial_merge_size
+            llm_grid_w = grid_w // self.spatial_merge_size
+            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                grid_t, llm_grid_h, llm_grid_w)
+            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
+            index_padded = index_padded.reshape(grid_t, num_windows_h,
+                                                vit_merger_window_size,
+                                                num_windows_w,
+                                                vit_merger_window_size)
+            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
+                vit_merger_window_size)
+            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index_new = index_padded[index_padded != -100]
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = seqlens.cumsum(
+                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+        return window_index, cu_window_seqlens
+

    def forward(
        self,
        x: torch.Tensor,
@@ -366,4 +438,37 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "visual"),
-        )
+        )
+
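+    # The two overrides below largely mirror the upstream Qwen2.5-VL logic:
+    # run pixel values (or precomputed embeddings) through the Ascend visual
+    # encoder, then split the concatenated output into one tensor per item.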
+    def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = image_input["image_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if image_input["type"] == "image_embeds":
+            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each image item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return image_embeds.split(sizes.tolist())
+
+    def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = video_input["video_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if video_input["type"] == "video_embeds":
+            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values_videos = video_input["pixel_values_videos"].type(
+                self.visual.dtype)
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each video item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return video_embeds.split(sizes.tolist())