Skip to content

Commit 67da572

Browse files
authored
[PERF] Speed up Qwen2.5-VL model by speed up rotary position embedding (#17973)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
1 parent 5c04bb8 commit 67da572

File tree

1 file changed

+121
-83
lines changed

1 file changed

+121
-83
lines changed

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 121 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# limitations under the License.
2626
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
2727
from collections.abc import Iterable, Mapping
28-
from functools import partial
28+
from functools import lru_cache, partial
2929
from typing import Callable, Literal, Optional, TypedDict, Union
3030

3131
import torch
@@ -478,8 +478,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None:
478478
super().__init__()
479479
self.dim = dim
480480
self.theta = theta
481-
inv_freq = 1.0 / (theta
482-
**(torch.arange(0, dim, 2, dtype=torch.float) / dim))
481+
inv_freq = 1.0 / (theta**(
482+
torch.arange(0, dim, 2, dtype=torch.float, device='cpu') / dim))
483483
self.register_buffer("inv_freq", inv_freq, persistent=False)
484484
self._seq_len_cached = 0
485485
self._freqs_cached = None
@@ -520,7 +520,7 @@ def __init__(
520520
self.hidden_size = vision_config.hidden_size
521521
self.num_heads = vision_config.num_heads
522522

523-
# args for get_window_index
523+
# args for get_window_index_thw
524524
self.window_size = vision_config.window_size
525525
self.patch_size = vision_config.patch_size
526526
self.spatial_merge_size = vision_config.spatial_merge_size
@@ -567,65 +567,71 @@ def dtype(self) -> torch.dtype:
567567
def device(self) -> torch.device:
568568
return self.patch_embed.proj.weight.device
569569

570-
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
571-
pos_ids = []
572-
for t, h, w in grid_thw:
573-
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
574-
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
575-
hpos_ids = hpos_ids.reshape(
576-
h // self.spatial_merge_size,
577-
self.spatial_merge_size,
578-
w // self.spatial_merge_size,
579-
self.spatial_merge_size,
580-
).permute(0, 2, 1, 3).flatten()
581-
wpos_ids = wpos_ids.reshape(
582-
h // self.spatial_merge_size,
583-
self.spatial_merge_size,
584-
w // self.spatial_merge_size,
585-
self.spatial_merge_size,
586-
).permute(0, 2, 1, 3).flatten()
587-
pos_ids.append(
588-
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
589-
pos_ids = torch.cat(pos_ids, dim=0)
590-
max_grid_size = grid_thw[:, 1:].max()
591-
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
570+
def rotary_pos_emb_thw(self, t, h, w):
571+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
572+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
573+
hpos_ids = hpos_ids.reshape(
574+
h // self.spatial_merge_size,
575+
self.spatial_merge_size,
576+
w // self.spatial_merge_size,
577+
self.spatial_merge_size,
578+
).permute(0, 2, 1, 3).flatten()
579+
wpos_ids = wpos_ids.reshape(
580+
h // self.spatial_merge_size,
581+
self.spatial_merge_size,
582+
w // self.spatial_merge_size,
583+
self.spatial_merge_size,
584+
).permute(0, 2, 1, 3).flatten()
585+
pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
586+
max_size = max(h, w)
587+
rotary_pos_emb_full = self.rotary_pos_emb(max_size)
592588
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
589+
rotary_pos_emb = rotary_pos_emb.reshape(
590+
rotary_pos_emb.shape[0] // self.spatial_merge_unit,
591+
self.spatial_merge_unit, -1)
592+
593593
return rotary_pos_emb
594594

595-
def get_window_index(self, grid_thw):
596-
window_index: list = []
597-
cu_window_seqlens: list = [0]
598-
window_index_id = 0
595+
def get_window_index_thw(self, grid_t, grid_h, grid_w):
599596
vit_merger_window_size = (self.window_size //
600597
self.spatial_merge_size // self.patch_size)
601598

602-
for grid_t, grid_h, grid_w in grid_thw:
603-
llm_grid_h = grid_h // self.spatial_merge_size
604-
llm_grid_w = grid_w // self.spatial_merge_size
605-
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
606-
grid_t, llm_grid_h, llm_grid_w)
607-
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
608-
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
609-
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
610-
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
611-
index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
612-
index_padded = index_padded.reshape(grid_t, num_windows_h,
613-
vit_merger_window_size,
614-
num_windows_w,
615-
vit_merger_window_size)
616-
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
617-
grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
618-
vit_merger_window_size)
619-
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
620-
index_padded = index_padded.reshape(-1)
621-
index_new = index_padded[index_padded != -100]
622-
window_index.append(index_new + window_index_id)
623-
cu_seqlens_tmp = seqlens.cumsum(
624-
0) * self.spatial_merge_unit + cu_window_seqlens[-1]
625-
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
626-
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
627-
window_index = torch.cat(window_index, dim=0)
628-
return window_index, cu_window_seqlens
599+
llm_grid_h = grid_h // self.spatial_merge_size
600+
llm_grid_w = grid_w // self.spatial_merge_size
601+
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
602+
grid_t, llm_grid_h, llm_grid_w)
603+
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
604+
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
605+
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
606+
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
607+
index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
608+
index_padded = index_padded.reshape(grid_t, num_windows_h,
609+
vit_merger_window_size,
610+
num_windows_w,
611+
vit_merger_window_size)
612+
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
613+
grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
614+
vit_merger_window_size)
615+
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
616+
index_padded = index_padded.reshape(-1)
617+
index_new = index_padded[index_padded != -100]
618+
cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit
619+
cu_seqlens_tmp = cu_seqlens_tmp.to(dtype=torch.int32)
620+
cu_seqlens_tmp = torch.unique_consecutive(cu_seqlens_tmp)
621+
622+
return index_new, cu_seqlens_tmp
623+
624+
@lru_cache(maxsize=1024) # noqa: B019
625+
def get_rope_by_thw(self, t, h, w):
626+
window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(
627+
t, h, w)
628+
rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w)
629+
rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :]
630+
rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_dim=0, end_dim=1)
631+
cu_seqlens_thw = torch.repeat_interleave(
632+
torch.tensor([h * w], dtype=torch.int32), t)
633+
return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw,
634+
cu_seqlens_thw)
629635

630636
def compute_attn_mask_seqlen(
631637
self,
@@ -641,45 +647,74 @@ def compute_attn_mask_seqlen(
641647
def forward(
642648
self,
643649
x: torch.Tensor,
644-
grid_thw: torch.Tensor,
650+
grid_thw: list[list[int]],
645651
) -> torch.Tensor:
646652
# patchify
653+
seq_len, _ = x.size()
654+
rotary_pos_emb = []
655+
window_index: list = []
656+
cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)]
657+
cu_seqlens: list = []
658+
647659
hidden_states = x.to(device=self.device, dtype=self.dtype)
648660
hidden_states = self.patch_embed(hidden_states)
649661

650-
# compute position embedding
651-
rotary_pos_emb = self.rot_pos_emb(grid_thw)
662+
window_index_id = 0
663+
cu_window_seqlens_last = 0
664+
for t, h, w in grid_thw:
665+
t, h, w = int(t), int(h), int(w)
666+
llm_h = h // self.spatial_merge_size
667+
llm_w = w // self.spatial_merge_size
668+
669+
(
670+
rotary_pos_emb_thw,
671+
window_index_thw,
672+
cu_seqlens_window_thw,
673+
cu_seqlens_thw,
674+
) = self.get_rope_by_thw(t, h, w)
675+
676+
window_index.append(window_index_thw + window_index_id)
677+
window_index_id += (t * llm_h * llm_w)
678+
679+
cu_seqlens_window_thw = (cu_seqlens_window_thw +
680+
cu_window_seqlens_last)
681+
cu_window_seqlens_last = cu_seqlens_window_thw[-1]
682+
cu_window_seqlens.append(cu_seqlens_window_thw)
652683

653-
# windows attention
654-
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
655-
cu_window_seqlens = torch.tensor(
656-
cu_window_seqlens,
657-
device=hidden_states.device,
658-
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
684+
rotary_pos_emb.append(rotary_pos_emb_thw)
685+
686+
cu_seqlens.append(cu_seqlens_thw)
687+
688+
rotary_pos_emb = torch.cat(rotary_pos_emb)
689+
window_index = torch.cat(window_index)
690+
cu_window_seqlens = torch.cat(cu_window_seqlens)
659691
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
660-
seq_len, _ = hidden_states.size()
661-
hidden_states = hidden_states.reshape(
662-
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
663-
hidden_states = hidden_states[window_index, :, :]
664-
hidden_states = hidden_states.reshape(seq_len, -1)
665-
rotary_pos_emb = rotary_pos_emb.reshape(
666-
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
667-
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
668-
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
669-
# compute cu_seqlens
670-
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
671-
grid_thw[:, 0]).cumsum(
672-
dim=0, dtype=torch.int32)
692+
cu_seqlens = torch.cat(cu_seqlens)
693+
cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
673694
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
674695

675696
# transformers
676-
hidden_states = hidden_states.unsqueeze(1)
677-
678697
# pre-compute seqlens for window/full attn to reduce cuMemcpy operations
679698
max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(
680699
cu_seqlens)
681700
max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
682701
cu_window_seqlens)
702+
703+
cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
704+
cu_window_seqlens = cu_window_seqlens.to(device=self.device,
705+
non_blocking=True)
706+
rotary_pos_emb = rotary_pos_emb.to(device=self.device,
707+
non_blocking=True)
708+
window_index = window_index.to(device=hidden_states.device,
709+
non_blocking=True)
710+
711+
hidden_states = hidden_states.reshape(
712+
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
713+
hidden_states = hidden_states[window_index, :, :]
714+
hidden_states = hidden_states.reshape(seq_len, -1)
715+
716+
hidden_states = hidden_states.unsqueeze(1)
717+
683718
for layer_num, blk in enumerate(self.blocks):
684719
if layer_num in self.fullatt_block_indexes:
685720
cu_seqlens_now = cu_seqlens
@@ -932,12 +967,13 @@ def _process_image_input(
932967

933968
grid_thw = image_input["image_grid_thw"]
934969
assert grid_thw.ndim == 2
970+
grid_thw_list = grid_thw.tolist()
935971

936972
if image_input["type"] == "image_embeds":
937973
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
938974
else:
939975
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
940-
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
976+
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
941977

942978
# Split concatenated embeddings for each image item.
943979
merge_size = self.visual.spatial_merge_size
@@ -951,13 +987,15 @@ def _process_video_input(
951987

952988
grid_thw = video_input["video_grid_thw"]
953989
assert grid_thw.ndim == 2
990+
grid_thw_list = grid_thw.tolist()
954991

955992
if video_input["type"] == "video_embeds":
956993
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
957994
else:
958995
pixel_values_videos = video_input["pixel_values_videos"].type(
959996
self.visual.dtype)
960-
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
997+
video_embeds = self.visual(pixel_values_videos,
998+
grid_thw=grid_thw_list)
961999

9621000
# Split concatenated embeddings for each video item.
9631001
merge_size = self.visual.spatial_merge_size

0 commit comments

Comments
 (0)