25
25
# limitations under the License.
26
26
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
27
27
from collections .abc import Iterable , Mapping
28
- from functools import partial
28
+ from functools import lru_cache , partial
29
29
from typing import Callable , Literal , Optional , TypedDict , Union
30
30
31
31
import torch
@@ -478,8 +478,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None:
478
478
super ().__init__ ()
479
479
self .dim = dim
480
480
self .theta = theta
481
- inv_freq = 1.0 / (theta
482
- ** ( torch .arange (0 , dim , 2 , dtype = torch .float ) / dim ))
481
+ inv_freq = 1.0 / (theta ** (
482
+ torch .arange (0 , dim , 2 , dtype = torch .float , device = 'cpu' ) / dim ))
483
483
self .register_buffer ("inv_freq" , inv_freq , persistent = False )
484
484
self ._seq_len_cached = 0
485
485
self ._freqs_cached = None
@@ -520,7 +520,7 @@ def __init__(
520
520
self .hidden_size = vision_config .hidden_size
521
521
self .num_heads = vision_config .num_heads
522
522
523
- # args for get_window_index
523
+ # args for get_window_index_thw
524
524
self .window_size = vision_config .window_size
525
525
self .patch_size = vision_config .patch_size
526
526
self .spatial_merge_size = vision_config .spatial_merge_size
@@ -567,65 +567,71 @@ def dtype(self) -> torch.dtype:
567
567
def device (self ) -> torch .device :
568
568
return self .patch_embed .proj .weight .device
569
569
570
- def rot_pos_emb (self , grid_thw : torch .Tensor ) -> torch .Tensor :
571
- pos_ids = []
572
- for t , h , w in grid_thw :
573
- hpos_ids = torch .arange (h ).unsqueeze (1 ).expand (- 1 , w )
574
- wpos_ids = torch .arange (w ).unsqueeze (0 ).expand (h , - 1 )
575
- hpos_ids = hpos_ids .reshape (
576
- h // self .spatial_merge_size ,
577
- self .spatial_merge_size ,
578
- w // self .spatial_merge_size ,
579
- self .spatial_merge_size ,
580
- ).permute (0 , 2 , 1 , 3 ).flatten ()
581
- wpos_ids = wpos_ids .reshape (
582
- h // self .spatial_merge_size ,
583
- self .spatial_merge_size ,
584
- w // self .spatial_merge_size ,
585
- self .spatial_merge_size ,
586
- ).permute (0 , 2 , 1 , 3 ).flatten ()
587
- pos_ids .append (
588
- torch .stack ([hpos_ids , wpos_ids ], dim = - 1 ).repeat (t , 1 ))
589
- pos_ids = torch .cat (pos_ids , dim = 0 )
590
- max_grid_size = grid_thw [:, 1 :].max ()
591
- rotary_pos_emb_full = self .rotary_pos_emb (max_grid_size )
570
def rotary_pos_emb_thw(self, t, h, w):
    """Build the rotary position embedding for a single ``(t, h, w)`` grid.

    Row and column indices of the ``h x w`` grid are regrouped so that the
    entries of each ``spatial_merge_size x spatial_merge_size`` cell are
    adjacent, the layout is repeated ``t`` times, and the resulting
    ``(row, col)`` pairs index into ``self.rotary_pos_emb(max(h, w))``.

    Args:
        t: Number of times the per-frame layout is repeated.
        h: Grid height; reshaped in units of ``spatial_merge_size``.
        w: Grid width; reshaped in units of ``spatial_merge_size``.

    Returns:
        The looked-up embedding reshaped to
        ``(rows // spatial_merge_unit, spatial_merge_unit, -1)``.
    """
    merge = self.spatial_merge_size
    cell_shape = (h // merge, merge, w // merge, merge)

    row_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
    col_ids = torch.arange(w).unsqueeze(0).expand(h, -1)

    # Swap the two middle axes so that all positions of one merge cell
    # come out next to each other once flattened.
    row_ids = row_ids.reshape(*cell_shape).permute(0, 2, 1, 3).flatten()
    col_ids = col_ids.reshape(*cell_shape).permute(0, 2, 1, 3).flatten()

    pos_ids = torch.stack([row_ids, col_ids], dim=-1).repeat(t, 1)

    # The lookup table only needs to cover the larger of the two extents.
    freq_table = self.rotary_pos_emb(max(h, w))
    emb = freq_table[pos_ids].flatten(1)

    unit = self.spatial_merge_unit
    return emb.reshape(emb.shape[0] // unit, unit, -1)
594
594
595
- def get_window_index (self , grid_thw ):
596
- window_index : list = []
597
- cu_window_seqlens : list = [0 ]
598
- window_index_id = 0
595
def get_window_index_thw(self, grid_t, grid_h, grid_w):
    """Compute the window-ordered token permutation for one grid.

    The merged grid (``grid_h`` and ``grid_w`` divided by
    ``spatial_merge_size``) is padded with the sentinel ``-100`` and cut
    into square windows of side
    ``window_size // spatial_merge_size // patch_size``.

    Args:
        grid_t: Leading (frame) extent of the grid.
        grid_h: Grid height before spatial merging.
        grid_w: Grid width before spatial merging.

    Returns:
        A pair ``(index_new, cu_seqlens)``: the merged-token indices
        listed window by window with padding removed, and the int32
        cumulative per-window lengths scaled by ``spatial_merge_unit``
        with consecutive duplicates collapsed.
    """
    merge = self.spatial_merge_size
    win = self.window_size // merge // self.patch_size

    rows = grid_h // merge
    cols = grid_w // merge
    token_ids = torch.arange(grid_t * rows * cols).reshape(
        grid_t, rows, cols)

    # An extent that is already a multiple of ``win`` still receives one
    # full window of padding; the empty windows this creates are dropped
    # again by ``unique_consecutive`` below.
    extra_rows = win - rows % win
    extra_cols = win - cols % win
    win_rows = (rows + extra_rows) // win
    win_cols = (cols + extra_cols) // win

    padded = F.pad(token_ids, (0, extra_cols, 0, extra_rows), 'constant',
                   -100)
    windows = padded.reshape(grid_t, win_rows, win, win_cols, win)
    windows = windows.permute(0, 1, 3, 2, 4).reshape(
        grid_t, win_rows * win_cols, win, win)

    is_real = windows != -100
    tokens_per_window = is_real.sum([2, 3]).reshape(-1)
    # Boolean masking walks the tensor in row-major order, i.e. window
    # by window, which is exactly the flattened-then-filtered order.
    index_new = windows[is_real]

    cu_seqlens = tokens_per_window.cumsum(0) * self.spatial_merge_unit
    cu_seqlens = cu_seqlens.to(dtype=torch.int32)
    cu_seqlens = torch.unique_consecutive(cu_seqlens)

    return index_new, cu_seqlens
623
+
624
@lru_cache(maxsize=1024)  # noqa: B019
def get_rope_by_thw(self, t, h, w):
    """Return the cached per-grid tensors for one ``(t, h, w)`` shape.

    Results are memoised on ``(self, t, h, w)`` (up to 1024 entries), so
    the arguments must be hashable and the returned tensors are shared
    between calls; callers should not modify them in place.

    Returns:
        A 4-tuple of:
        - the rotary embedding from ``rotary_pos_emb_thw``, reordered by
          the window index and with its first two dimensions flattened;
        - the window index from ``get_window_index_thw``;
        - the cumulative window lengths from ``get_window_index_thw``;
        - an int32 tensor holding ``h * w`` repeated ``t`` times.
    """
    window_index, cu_window_seqlens = self.get_window_index_thw(t, h, w)

    rope = self.rotary_pos_emb_thw(t, h, w)
    rope = rope[window_index, :, :].flatten(start_dim=0, end_dim=1)

    # One entry of length h * w per frame; the caller accumulates these.
    frame_len = torch.tensor([h * w], dtype=torch.int32)
    cu_seqlens = torch.repeat_interleave(frame_len, t)

    return (rope, window_index, cu_window_seqlens, cu_seqlens)
629
635
630
636
def compute_attn_mask_seqlen (
631
637
self ,
@@ -641,45 +647,74 @@ def compute_attn_mask_seqlen(
641
647
def forward (
642
648
self ,
643
649
x : torch .Tensor ,
644
- grid_thw : torch . Tensor ,
650
+ grid_thw : list [ list [ int ]] ,
645
651
) -> torch .Tensor :
646
652
# patchify
653
+ seq_len , _ = x .size ()
654
+ rotary_pos_emb = []
655
+ window_index : list = []
656
+ cu_window_seqlens : list = [torch .tensor ([0 ], dtype = torch .int32 )]
657
+ cu_seqlens : list = []
658
+
647
659
hidden_states = x .to (device = self .device , dtype = self .dtype )
648
660
hidden_states = self .patch_embed (hidden_states )
649
661
650
- # compute position embedding
651
- rotary_pos_emb = self .rot_pos_emb (grid_thw )
662
+ window_index_id = 0
663
+ cu_window_seqlens_last = 0
664
+ for t , h , w in grid_thw :
665
+ t , h , w = int (t ), int (h ), int (w )
666
+ llm_h = h // self .spatial_merge_size
667
+ llm_w = w // self .spatial_merge_size
668
+
669
+ (
670
+ rotary_pos_emb_thw ,
671
+ window_index_thw ,
672
+ cu_seqlens_window_thw ,
673
+ cu_seqlens_thw ,
674
+ ) = self .get_rope_by_thw (t , h , w )
675
+
676
+ window_index .append (window_index_thw + window_index_id )
677
+ window_index_id += (t * llm_h * llm_w )
678
+
679
+ cu_seqlens_window_thw = (cu_seqlens_window_thw +
680
+ cu_window_seqlens_last )
681
+ cu_window_seqlens_last = cu_seqlens_window_thw [- 1 ]
682
+ cu_window_seqlens .append (cu_seqlens_window_thw )
652
683
653
- # windows attention
654
- window_index , cu_window_seqlens = self .get_window_index (grid_thw )
655
- cu_window_seqlens = torch .tensor (
656
- cu_window_seqlens ,
657
- device = hidden_states .device ,
658
- dtype = grid_thw .dtype if torch .jit .is_tracing () else torch .int32 )
684
+ rotary_pos_emb .append (rotary_pos_emb_thw )
685
+
686
+ cu_seqlens .append (cu_seqlens_thw )
687
+
688
+ rotary_pos_emb = torch .cat (rotary_pos_emb )
689
+ window_index = torch .cat (window_index )
690
+ cu_window_seqlens = torch .cat (cu_window_seqlens )
659
691
cu_window_seqlens = torch .unique_consecutive (cu_window_seqlens )
660
- seq_len , _ = hidden_states .size ()
661
- hidden_states = hidden_states .reshape (
662
- seq_len // self .spatial_merge_unit , self .spatial_merge_unit , - 1 )
663
- hidden_states = hidden_states [window_index , :, :]
664
- hidden_states = hidden_states .reshape (seq_len , - 1 )
665
- rotary_pos_emb = rotary_pos_emb .reshape (
666
- seq_len // self .spatial_merge_unit , self .spatial_merge_unit , - 1 )
667
- rotary_pos_emb = rotary_pos_emb [window_index , :, :]
668
- rotary_pos_emb = rotary_pos_emb .reshape (seq_len , - 1 )
669
- # compute cu_seqlens
670
- cu_seqlens = torch .repeat_interleave (grid_thw [:, 1 ] * grid_thw [:, 2 ],
671
- grid_thw [:, 0 ]).cumsum (
672
- dim = 0 , dtype = torch .int32 )
692
+ cu_seqlens = torch .cat (cu_seqlens )
693
+ cu_seqlens = torch .cumsum (cu_seqlens , dim = 0 , dtype = torch .int32 )
673
694
cu_seqlens = F .pad (cu_seqlens , (1 , 0 ), "constant" , 0 )
674
695
675
696
# transformers
676
- hidden_states = hidden_states .unsqueeze (1 )
677
-
678
697
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
679
698
max_seqlen_full , seqlens_full = self .compute_attn_mask_seqlen (
680
699
cu_seqlens )
681
700
max_seqlen_window , seqlens_window = self .compute_attn_mask_seqlen (
682
701
cu_window_seqlens )
702
+
703
+ cu_seqlens = cu_seqlens .to (device = self .device , non_blocking = True )
704
+ cu_window_seqlens = cu_window_seqlens .to (device = self .device ,
705
+ non_blocking = True )
706
+ rotary_pos_emb = rotary_pos_emb .to (device = self .device ,
707
+ non_blocking = True )
708
+ window_index = window_index .to (device = hidden_states .device ,
709
+ non_blocking = True )
710
+
711
+ hidden_states = hidden_states .reshape (
712
+ seq_len // self .spatial_merge_unit , self .spatial_merge_unit , - 1 )
713
+ hidden_states = hidden_states [window_index , :, :]
714
+ hidden_states = hidden_states .reshape (seq_len , - 1 )
715
+
716
+ hidden_states = hidden_states .unsqueeze (1 )
717
+
683
718
for layer_num , blk in enumerate (self .blocks ):
684
719
if layer_num in self .fullatt_block_indexes :
685
720
cu_seqlens_now = cu_seqlens
@@ -932,12 +967,13 @@ def _process_image_input(
932
967
933
968
grid_thw = image_input ["image_grid_thw" ]
934
969
assert grid_thw .ndim == 2
970
+ grid_thw_list = grid_thw .tolist ()
935
971
936
972
if image_input ["type" ] == "image_embeds" :
937
973
image_embeds = image_input ["image_embeds" ].type (self .visual .dtype )
938
974
else :
939
975
pixel_values = image_input ["pixel_values" ].type (self .visual .dtype )
940
- image_embeds = self .visual (pixel_values , grid_thw = grid_thw )
976
+ image_embeds = self .visual (pixel_values , grid_thw = grid_thw_list )
941
977
942
978
# Split concatenated embeddings for each image item.
943
979
merge_size = self .visual .spatial_merge_size
@@ -951,13 +987,15 @@ def _process_video_input(
951
987
952
988
grid_thw = video_input ["video_grid_thw" ]
953
989
assert grid_thw .ndim == 2
990
+ grid_thw_list = grid_thw .tolist ()
954
991
955
992
if video_input ["type" ] == "video_embeds" :
956
993
video_embeds = video_input ["video_embeds" ].type (self .visual .dtype )
957
994
else :
958
995
pixel_values_videos = video_input ["pixel_values_videos" ].type (
959
996
self .visual .dtype )
960
- video_embeds = self .visual (pixel_values_videos , grid_thw = grid_thw )
997
+ video_embeds = self .visual (pixel_values_videos ,
998
+ grid_thw = grid_thw_list )
961
999
962
1000
# Split concatenated embeddings for each video item.
963
1001
merge_size = self .visual .spatial_merge_size
0 commit comments