
Commit 1e56aae

[0.7.3] optimize Qwen2.5 vl vit (#623)
### What this PR does / why we need it?
Optimizes the Qwen2.5-VL vision transformer (ViT) with PTA (torch_npu) operators.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
We benchmarked the change; accuracy and performance are equivalent to the original implementation.

---------

Signed-off-by: zouyida <zouyida@huawei.com>
1 parent 4c41672 commit 1e56aae

File tree

4 files changed: +399 −2 lines


vllm_ascend/models/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -6,6 +6,11 @@ def register_model():
         "Qwen2VLForConditionalGeneration",
         "vllm_ascend.models.qwen2_vl:CustomQwen2VLForConditionalGeneration")
 
+    ModelRegistry.register_model(
+        "Qwen2_5_VLForConditionalGeneration",
+        "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
+    )
+
     ModelRegistry.register_model(
         "DeepseekV2ForCausalLM",
         "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")

vllm_ascend/models/qwen2_5_vl.py

Lines changed: 365 additions & 0 deletions
@@ -0,0 +1,365 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Callable, Iterable, Optional, Set, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
    Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
    Qwen2_5_VLProcessingInfo)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY

MIN_PAD_SIZE = 64
MAX_PAD_SIZE = 128
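
# Editor's note (not part of the patch): for the Qwen2.5-VL vision tower
# (hidden_size 1280, 16 attention heads) the per-head size is 1280 // 16 = 80,
# which lies strictly between MIN_PAD_SIZE and MAX_PAD_SIZE. Each rotary half
# of 40 elements is therefore zero-padded by (128 - 80) // 2 = 24 so that the
# NPU rotary and attention kernels run with a head size of 128.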

class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            embed_dim,
            num_heads,
            projection_size,
            quant_config,
            prefix,
        )
        self.embed_dim = embed_dim
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        if MIN_PAD_SIZE < self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))
        q = torch_npu.npu_rotary_mul(q, cos, sin)
        k = torch_npu.npu_rotary_mul(k, cos, sin)

        q, k, v = [
            rearrange(x, "b s h d -> (b s) h d").contiguous()
            for x in (q, k, v)
        ]

        context_layer = torch.empty_like(q)

        # operator requires pta version >= 2.5.1.dev20250226
        torch_npu._npu_flash_attention_unpad(
            query=q,
            key=k,
            value=v,
            seq_len=cu_seqlens,
            scale_value=self.hidden_size_per_attention_head**-0.5,
            num_heads=self.num_attention_heads_per_partition,
            num_kv_heads=self.num_attention_heads_per_partition,
            out=context_layer)

        context_layer = rearrange(context_layer,
                                  "(b s) h d -> s b (h d)",
                                  b=batch_size).contiguous()

        output, _ = self.proj(context_layer)
        return output
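
# Editor's note (not part of the patch): the forward above swaps the upstream
# attention backend for two torch_npu kernels. npu_rotary_mul applies the
# rotary embedding to q/k laid out as [batch, seq, heads, head_dim] against
# broadcastable cos/sin of shape [1, seq, 1, head_dim], and
# _npu_flash_attention_unpad consumes the flattened [(batch * seq), heads,
# head_dim] tensors together with per-sequence lengths (the cu_seqlens passed
# in hold lengths, not cumulative offsets).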

class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock):

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
                         quant_config, prefix)
        self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim,
                                                  num_heads=num_heads,
                                                  projection_size=dim,
                                                  quant_config=quant_config,
                                                  prefix=f"{prefix}.attn")

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(
            self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)
        x = x + self.mlp(self.norm2(x))
        return x


class AscendQwen2_5_VisionPatchEmbed(Qwen2_5_VisionPatchEmbed):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.matmul(
            self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
        return x

class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):

    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        interleaved=False,
    ) -> None:
        super().__init__(vision_config, norm_eps, quant_config, prefix)
        norm_layer = partial(RMSNorm, eps=norm_eps)
        self.interleaved = interleaved
        self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
            patch_size=vision_config.patch_size,
            temporal_patch_size=vision_config.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )
        self.blocks = nn.ModuleList([
            AscendQwen2_5_VisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}")
            for layer_idx in range(vision_config.depth)
        ])
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            self.hidden_size, self.num_heads)

        if MIN_PAD_SIZE < self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
            self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2
            self.half_pad_hidden_size_per_attention_head = (
                MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

    def cal_cos_sin(self, rotary_pos_emb):
        cos = rotary_pos_emb.cos()  # [seqlen, rotary_dim / 2]
        sin = rotary_pos_emb.sin()
        cos = torch.nn.functional.pad(
            cos, (0, self.half_pad_hidden_size_per_attention_head))
        sin = torch.nn.functional.pad(
            sin, (0, self.half_pad_hidden_size_per_attention_head))

        if not self.interleaved:
            cos_new = torch.cat((cos, cos), dim=-1)
            sin_new = torch.cat((sin, sin), dim=-1)
        else:
            cos_new = rearrange(torch.stack((cos, cos), dim=-1),
                                "... d two -> ...(d two)",
                                two=2)
            sin_new = rearrange(torch.stack((sin, sin), dim=-1),
                                "... d two -> ...(d two)",
                                two=2)
        cos_new = cos_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        sin_new = sin_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        return cos_new, sin_new
    def pad_qkv_bias(self, bias):
        first_half = bias.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head
        )[:, :, :self.half_origin_hidden_size_per_attention_head]
        second_half = bias.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head
        )[:, :, self.half_origin_hidden_size_per_attention_head:]
        first_half_padded = torch.nn.functional.pad(
            first_half, (0, self.half_pad_hidden_size_per_attention_head))
        second_half_padded = torch.nn.functional.pad(
            second_half, (0, self.half_pad_hidden_size_per_attention_head))
        bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2)
        bias_final = bias_padded.reshape(-1)
        return bias_final

    def pad_qkv_weight(self, data):
        qkv_weight_first_half = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
        )[:, :, :self.half_origin_hidden_size_per_attention_head, :]
        qkv_weight_second_half = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
        )[:, :, self.half_origin_hidden_size_per_attention_head:, :]

        qkv_weight_first_half_padded = torch.nn.functional.pad(
            qkv_weight_first_half,
            (0, 0, 0, self.half_pad_hidden_size_per_attention_head))
        qkv_weight_second_half_padded = torch.nn.functional.pad(
            qkv_weight_second_half,
            (0, 0, 0, self.half_pad_hidden_size_per_attention_head))
        qkv_weight_padded = torch.cat(
            [qkv_weight_first_half_padded, qkv_weight_second_half_padded],
            dim=2)
        qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
        return qkv_weight_final

    def pad_proj_weight(self, data):
        out_weight = torch.nn.functional.pad(
            data.reshape(self.hidden_size, -1,
                         self.half_origin_hidden_size_per_attention_head),
            (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
                self.hidden_size, -1)
        return out_weight
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                if "attn.proj.weight" in name:
                    param.data = self.pad_proj_weight(param.data)
                if "attn.qkv.weight" in name:
                    param.data = self.pad_qkv_weight(param.data)
                if "attn.qkv.bias" in name:
                    param.data = self.pad_qkv_bias(param.data)
            loaded_params.add(name)
        return loaded_params
    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        # compute cu_seqlens
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2],
            grid_thw[:, 0]).cpu().to(torch.int32)

        # patchify
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # windows attention
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
        cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32)
        seq_len, _ = x.size()
        x = x.reshape(seq_len // self.spatial_merge_unit,
                      self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        cos, sin = self.cal_cos_sin(rotary_pos_emb)

        # transformers
        x = x.unsqueeze(1)
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin)

        # adapter
        x = self.merger(x)
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]
        return x

@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
    info=Qwen2_5_VLProcessingInfo,
    dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class AscendQwen2_5_VLForConditionalGeneration(
        Qwen2_5_VLForConditionalGeneration):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.visual = AscendQwen2_5_VisionTransformer(
            vision_config=config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "visual"),
        )
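
Editor's note, not from the patch: the weight padding performed in load_weights can be illustrated with a self-contained sketch. pad_qkv_weight views the fused qkv weight as (-1, 3, head_dim, hidden_size), zero-pads each rotary half of the head dimension from 40 to 64 rows, and flattens the result back; the sizes below are illustrative and match the Qwen2.5-VL vision tower:

import torch
import torch.nn.functional as F

hidden_size, num_heads, head_dim, pad_to = 1280, 16, 80, 128
half, half_pad = head_dim // 2, (pad_to - head_dim) // 2

qkv_weight = torch.randn(3 * num_heads * head_dim, hidden_size)

w = qkv_weight.reshape(-1, 3, head_dim, hidden_size)
first = F.pad(w[:, :, :half, :], (0, 0, 0, half_pad))   # pad first rotary half 40 -> 64
second = F.pad(w[:, :, half:, :], (0, 0, 0, half_pad))  # pad second rotary half 40 -> 64
padded = torch.cat([first, second], dim=2).reshape(-1, hidden_size)

print(qkv_weight.shape, "->", padded.shape)
# torch.Size([3840, 1280]) -> torch.Size([6144, 1280])

pad_proj_weight applies the mirror-image padding to the projection's input columns, so the padded head dimensions (which only ever carry zeros) contribute nothing to the projected output.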
