
Commit ba9714c

Optimize qwen2_vl and qwen2_5_vl (#701)
### What this PR does / why we need it?
Optimize qwen2_vl and qwen2_5_vl.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Tested this PR with a 1080p picture at tp=1, bs=1 on Qwen2-VL and Qwen2.5-VL: the execution time of each FA (flash attention) op dropped from 11 ms to 9 ms, roughly a 22% performance boost.

---------

Signed-off-by: zouyida2052 <zouyida@huawei.com>
Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
Co-authored-by: zouyida2052 <zouyida@huawei.com>
1 parent 90aabae commit ba9714c
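For reference, a minimal offline-inference sketch of the kind of run described above (tp=1, bs=1, one 1080p image), assuming vLLM and vllm-ascend are installed on an Ascend NPU host. The model name, image path, and prompt template are assumptions for illustration, not part of this PR:

```python
from PIL import Image
from vllm import LLM, SamplingParams

# Assumed checkpoint and test image; adjust to your environment.
llm = LLM(model="Qwen/Qwen2.5-VL-7B-Instruct", tensor_parallel_size=1)  # tp=1
image = Image.open("test_1080p.jpg")  # any 1920x1080 picture

prompt = ("<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
          "Describe this image.<|im_end|>\n<|im_start|>assistant\n")

# bs=1: a single multimodal request through the vision tower.
outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"image": image},
}, SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```

Per-op timings such as the 11 ms → 9 ms figure come from profiling the vision tower's attention ops during a run like this.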

File tree: 4 files changed, +559 −27 lines


vllm_ascend/models/__init__.py

Lines changed: 9 additions & 2 deletions
@@ -5,15 +5,22 @@ def register_model():
     from .deepseek_mtp import CustomDeepSeekMTP  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV3ForCausalLM  # noqa: F401
-    from .qwen2_vl import CustomQwen2VLForConditionalGeneration  # noqa: F401
+    from .qwen2_5_vl import \
+        AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
+    from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
 
     ModelRegistry.register_model(
         "DeepSeekMTPModel",
         "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
 
     ModelRegistry.register_model(
         "Qwen2VLForConditionalGeneration",
-        "vllm_ascend.models.qwen2_vl:CustomQwen2VLForConditionalGeneration")
+        "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
+
+    ModelRegistry.register_model(
+        "Qwen2_5_VLForConditionalGeneration",
+        "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
+    )
 
     ModelRegistry.register_model(
         "DeepseekV2ForCausalLM",

vllm_ascend/models/qwen2_5_vl.py

Lines changed: 369 additions & 0 deletions
@@ -0,0 +1,369 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Callable, Iterable, Optional, Set, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
    Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
    Qwen2_5_VLProcessingInfo)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY

MIN_PAD_SIZE = 64  # min_size to pad weight
MAX_PAD_SIZE = 128  # max_size to pad weight


class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            embed_dim,
            num_heads,
            projection_size,
            quant_config,
            prefix,
        )
        self.embed_dim = embed_dim
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
        if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))
        q = torch_npu.npu_rotary_mul(q, cos, sin)
        k = torch_npu.npu_rotary_mul(k, cos, sin)

        q, k, v = [
            rearrange(x, "b s h d -> (b s) h d").contiguous()
            for x in (q, k, v)
        ]

        context_layer = torch.torch.empty_like(q)

        # operator requires pta version >= 2.5.1
        torch_npu._npu_flash_attention_unpad(
            query=q,
            key=k,
            value=v,
            seq_len=cu_seqlens,
            scale_value=self.origin_hidden_size_per_attention_head**-0.5,
            num_heads=self.num_attention_heads_per_partition,
            num_kv_heads=self.num_attention_heads_per_partition,
            out=context_layer)

        context_layer = rearrange(context_layer,
                                  "(b s) h d -> s b (h d)",
                                  b=batch_size).contiguous()

        output, _ = self.proj(context_layer)
        return output


class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock):

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
                         quant_config, prefix)
        self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim,
                                                  num_heads=num_heads,
                                                  projection_size=dim,
                                                  quant_config=quant_config,
                                                  prefix=f"{prefix}.attn")

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(
            self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)

        x = x + self.mlp(self.norm2(x))
        return x


class AscendQwen2_5_VisionPatchEmbed(Qwen2_5_VisionPatchEmbed):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.matmul(
            self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
        return x


class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):

    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        interleaved=False,
    ) -> None:
        super().__init__(vision_config, norm_eps, quant_config, prefix)
        norm_layer = partial(RMSNorm, eps=norm_eps)
        self.interleaved = interleaved
        self.enable_pad = False
        self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
            patch_size=vision_config.patch_size,
            temporal_patch_size=vision_config.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )
        self.blocks = nn.ModuleList([
            AscendQwen2_5_VisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}")
            for layer_idx in range(vision_config.depth)
        ])
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            self.hidden_size, self.num_heads)

        if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.enable_pad = True
            self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
            self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2
            self.half_pad_hidden_size_per_attention_head = (
                MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

    def cal_cos_sin(self, rotary_pos_emb):
        cos = rotary_pos_emb.cos()  # [seqlen, rotary_dim / 2]
        sin = rotary_pos_emb.sin()
        if self.enable_pad:
            cos = torch.nn.functional.pad(
                cos, (0, self.half_pad_hidden_size_per_attention_head))
            sin = torch.nn.functional.pad(
                sin, (0, self.half_pad_hidden_size_per_attention_head))

        if not self.interleaved:
            cos_new = torch.cat((cos, cos), dim=-1)
            sin_new = torch.cat((sin, sin), dim=-1)
        else:
            cos_new = rearrange(torch.stack((cos, cos), dim=-1),
                                "... d two -> ...(d two)",
                                two=2)
            sin_new = rearrange(torch.stack((sin, sin), dim=-1),
                                "... d two -> ...(d two)",
                                two=2)
        cos_new = cos_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        sin_new = sin_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        return cos_new, sin_new

    def pad_qkv_bias(self, bias):
        first_half = bias.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head
        )[:, :, :self.half_origin_hidden_size_per_attention_head]
        second_half = bias.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head
        )[:, :, self.half_origin_hidden_size_per_attention_head:]
        first_half_padded = torch.nn.functional.pad(
            first_half, (0, self.half_pad_hidden_size_per_attention_head))
        second_half_padded = torch.nn.functional.pad(
            second_half, (0, self.half_pad_hidden_size_per_attention_head))
        bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2)
        bias_final = bias_padded.reshape(-1)
        return bias_final

    def pad_qkv_weight(self, data):
        qkv_weight_first_half = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
        )[:, :, :self.half_origin_hidden_size_per_attention_head, :]
        qkv_weight_second_half = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
        )[:, :, self.half_origin_hidden_size_per_attention_head:, :]

        qkv_weight_first_half_padded = torch.nn.functional.pad(
            qkv_weight_first_half,
            (0, 0, 0, self.half_pad_hidden_size_per_attention_head))
        qkv_weight_second_half_padded = torch.nn.functional.pad(
            qkv_weight_second_half,
            (0, 0, 0, self.half_pad_hidden_size_per_attention_head))
        qkv_weight_padded = torch.cat(
            [qkv_weight_first_half_padded, qkv_weight_second_half_padded],
            dim=2)
        qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
        return qkv_weight_final

    def pad_proj_weight(self, data):
        out_weight = torch.nn.functional.pad(
            data.reshape(self.hidden_size, -1,
                         self.half_origin_hidden_size_per_attention_head),
            (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
                self.hidden_size, -1)
        return out_weight

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                if ("attn.proj.weight" in name) and self.enable_pad:
                    param.data = self.pad_proj_weight(param.data)
                if ("attn.qkv.weight" in name) and self.enable_pad:
                    param.data = self.pad_qkv_weight(param.data)
                if ("attn.qkv.bias" in name) and self.enable_pad:
                    param.data = self.pad_qkv_bias(param.data)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        # compute cu_seqlens
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                             grid_thw[:, 0]).cpu().to(
                                                 torch.int32)

        # patchify
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # windows attention
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
        cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32)
        seq_len, _ = x.size()
        x = x.reshape(seq_len // self.spatial_merge_unit,
                      self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        cos, sin = self.cal_cos_sin(rotary_pos_emb)

        # transformers
        x = x.unsqueeze(1)
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin)

        # adapter
        x = self.merger(x)
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]
        return x


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
    info=Qwen2_5_VLProcessingInfo,
    dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class AscendQwen2_5_VLForConditionalGeneration(
        Qwen2_5_VLForConditionalGeneration):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.visual = AscendQwen2_5_VisionTransformer(
            vision_config=config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "visual"),
        )
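The core of the Ascend-side change is head-dimension padding: Qwen2.5-VL's vision tower uses a head_dim that falls between MIN_PAD_SIZE (64) and MAX_PAD_SIZE (128), so at weight-loading time each rotary half of the q/k/v weights is padded up so the NPU flash-attention op sees an aligned head_dim of 128, with `pad_proj_weight` and `cal_cos_sin` applying the matching padding to the output projection and the cos/sin tables. A standalone sketch of that arithmetic, not part of the PR (hidden_size 1280 and 16 heads, hence head_dim 80, are assumed shapes for illustration):

```python
import torch

# Assumed vision-tower shapes: hidden_size 1280, 16 heads -> head_dim 80.
hidden_size, num_heads = 1280, 16
head_dim = hidden_size // num_heads               # 80: between MIN_PAD_SIZE (64) and MAX_PAD_SIZE (128)
half, pad = head_dim // 2, (128 - head_dim) // 2  # 40 and 24: each rotary half is padded 40 -> 64

# Mimics pad_qkv_weight: split every (q|k|v) head row-block into its two rotary
# halves and pad each half along the head_dim axis, giving an effective head_dim of 128.
qkv_weight = torch.randn(3 * hidden_size, hidden_size)
w = qkv_weight.reshape(-1, 3, head_dim, hidden_size)   # (16, 3, 80, 1280)
first, second = w[:, :, :half, :], w[:, :, half:, :]   # two (16, 3, 40, 1280) halves
w_padded = torch.cat([
    torch.nn.functional.pad(first, (0, 0, 0, pad)),    # (16, 3, 64, 1280)
    torch.nn.functional.pad(second, (0, 0, 0, pad)),
], dim=2).reshape(-1, hidden_size)
print(w_padded.shape)  # torch.Size([6144, 1280]): 16 heads * 3 * 128 rows, up from 3840
```

Because the padded rows are zeros, the extra output channels contribute nothing to the attention result; the attention scale is still computed from the original head_dim (`origin_hidden_size_per_attention_head**-0.5`), so the numerics are unchanged while the op runs on an aligned shape.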
