
Commit 7326644

[CI] Fix qwen2.5 vl CI failure (#888)
The [vllm commit](vllm-project/vllm@67da572) changed the input handling and rotary position embedding for qwen2.5 vl, which broke CI. This PR fixes the CI failure for qwen2.5 vl.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent df16c4f commit 7326644

File tree

1 file changed: +109 -4 lines changed


vllm_ascend/models/qwen2_5_vl.py

Lines changed: 109 additions & 4 deletions
@@ -36,9 +36,9 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.qwen2_5_vl import (
     Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
-    Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
-    Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
-    Qwen2_5_VLProcessingInfo)
+    Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer,
+    Qwen2_5_VLDummyInputsBuilder, Qwen2_5_VLForConditionalGeneration,
+    Qwen2_5_VLMultiModalProcessor, Qwen2_5_VLProcessingInfo)
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.multimodal import MULTIMODAL_REGISTRY

@@ -152,6 +152,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


+class AscendQwen2_5_VisionRotaryEmbedding(Qwen2_5_VisionRotaryEmbedding):
+
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__(dim, theta)
+        inv_freq = 1.0 / (theta
+                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.inv_freq = inv_freq
+
+
 class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):

     def __init__(
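The override above recomputes the inverse-frequency table explicitly in float32 after calling the parent constructor. For readers unfamiliar with the construction, here is a minimal standalone sketch of the same formula and of how a rotary embedding typically turns it into per-position frequencies; the `dim`, `theta`, and `max_grid_size` values below are illustrative assumptions, not values taken from this commit.

```python
import torch

# Minimal sketch (assumed values): dim is the rotary dimension,
# theta the base used by Qwen2.5 VL-style rotary embeddings.
dim, theta, max_grid_size = 40, 10000.0, 8

# Same formula as the override above: one inverse frequency per pair of channels.
inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float) / dim))

# A frequency table indexed by position, as a rotary embedding would produce
# when called with a sequence length (here the largest spatial grid size).
seq = torch.arange(max_grid_size, dtype=torch.float)
freqs = torch.outer(seq, inv_freq)  # shape: (max_grid_size, dim // 2)
print(inv_freq.shape, freqs.shape)  # torch.Size([20]) torch.Size([8, 20])
```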
@@ -166,6 +175,9 @@ def __init__(
         norm_layer = partial(RMSNorm, eps=norm_eps)
         self.interleaved = interleaved
         self.enable_pad = False
+        head_dim = self.hidden_size // self.num_heads
+        self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim //
+                                                                  2)
         self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
             patch_size=vision_config.patch_size,
             temporal_patch_size=vision_config.temporal_patch_size,
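In the constructor, the rotary module is sized from the per-head channel count. A small arithmetic illustration with assumed Qwen2.5 VL-style vision config numbers (the real values come from `vision_config` at load time):

```python
# Assumed example values for a Qwen2.5 VL-style vision config; the real
# hidden_size / num_heads are read from vision_config when the model loads.
hidden_size, num_heads = 1280, 16

head_dim = hidden_size // num_heads  # 80 channels per attention head
rotary_dim = head_dim // 2           # 40 -> passed to AscendQwen2_5_VisionRotaryEmbedding
print(head_dim, rotary_dim)          # 80 40
```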
@@ -298,6 +310,66 @@ def load_weights(self, weights: Iterable[Tuple[str,
             loaded_params.add(name)
         return loaded_params

+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            pos_ids.append(
+                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def get_window_index(self, grid_thw):
+        window_index: list = []
+        cu_window_seqlens: list = [0]
+        window_index_id = 0
+        vit_merger_window_size = (self.window_size //
+                                  self.spatial_merge_size // self.patch_size)
+
+        for grid_t, grid_h, grid_w in grid_thw:
+            llm_grid_h = grid_h // self.spatial_merge_size
+            llm_grid_w = grid_w // self.spatial_merge_size
+            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                grid_t, llm_grid_h, llm_grid_w)
+            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
+            index_padded = index_padded.reshape(grid_t, num_windows_h,
+                                                vit_merger_window_size,
+                                                num_windows_w,
+                                                vit_merger_window_size)
+            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
+                vit_merger_window_size)
+            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index_new = index_padded[index_padded != -100]
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = seqlens.cumsum(
+                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+        return window_index, cu_window_seqlens
+
     def forward(
         self,
         x: torch.Tensor,
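The added `rot_pos_emb` builds per-patch (h, w) position ids and reorders them so that patches inside each `spatial_merge_size x spatial_merge_size` block become contiguous; `get_window_index` then applies an analogous windowed regrouping with `-100` padding. Below is a toy sketch of just the reshape/permute/flatten step on a 4x4 grid with a merge size of 2 (values chosen purely for illustration):

```python
import torch

# Toy example: a 4x4 patch grid with spatial_merge_size = 2 (assumed values).
h, w, merge = 4, 4, 2

hpos = torch.arange(h).unsqueeze(1).expand(-1, w)  # row index of every patch
# Group rows/cols into merge x merge blocks, then flatten block by block,
# mirroring the reshape/permute/flatten used in rot_pos_emb above.
hpos_merged = hpos.reshape(h // merge, merge,
                           w // merge, merge).permute(0, 2, 1, 3).flatten()
print(hpos.flatten().tolist())
# [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]  (plain row-major order)
print(hpos_merged.tolist())
# [0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3, 3]  (2x2 blocks are contiguous)
```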
@@ -366,4 +438,37 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=maybe_prefix(prefix, "visual"),
-        )
+        )
+
+    def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = image_input["image_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if image_input["type"] == "image_embeds":
+            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each image item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return image_embeds.split(sizes.tolist())
+
+    def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = video_input["video_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if video_input["type"] == "video_embeds":
+            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values_videos = video_input["pixel_values_videos"].type(
+                self.visual.dtype)
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each video item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return video_embeds.split(sizes.tolist())
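Both `_process_image_input` and `_process_video_input` split the concatenated vision embeddings back into one tensor per media item, dividing each item's patch count by `spatial_merge_size` twice. A small numeric sketch with made-up grids and a dummy hidden size (shapes only, no real model involved):

```python
import torch

# Two assumed media items: grid_thw rows are (t, h, w) patch counts per item.
grid_thw = torch.tensor([[1, 4, 4],    # e.g. one image: 1*4*4 = 16 patches
                         [2, 4, 8]])   # e.g. one video:  2*4*8 = 64 patches
merge_size = 2                          # assumed spatial_merge_size

# Same arithmetic as _process_image_input / _process_video_input above.
sizes = grid_thw.prod(-1) // merge_size // merge_size   # tensor([ 4, 16])

# The visual encoder would return 4 + 16 = 20 merged tokens concatenated;
# split() hands each item its own slice.
embeds = torch.randn(int(sizes.sum()), 8)               # 8 = dummy hidden size
per_item = embeds.split(sizes.tolist())
print([e.shape for e in per_item])   # [torch.Size([4, 8]), torch.Size([16, 8])]
```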
