
Commit c874214

[CherryPick] Add unpadded Qwen2.5-VL for verl scenario (#1095)
Add unpadded Qwen2.5-VL for the verl scenario. When using vllm-ascend in the verl scenario, set `USE_OPTIMIZED_MODEL` (default `1`) to `0` to use the unpadded Qwen2.5-VL and avoid errors. This is cherry-picked from 0.7.3-dev.

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Shanshan Shen <467638484@qq.com>
1 parent b80a484 commit c874214
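For verl users, the setting described in the commit message amounts to exporting the variable before vLLM and the vllm-ascend plugin are imported, since the plugin reads it when it registers models. A minimal sketch, not part of this commit; the model ID and script layout are only illustrative:

import os

# Opt out of the optimized (padded) Qwen2.5-VL; the default "1" keeps it enabled.
# Set this before importing vLLM so the plugin sees it during model registration.
os.environ["USE_OPTIMIZED_MODEL"] = "0"

from vllm import LLM  # noqa: E402  (imported after setting the env var on purpose)

# Illustrative checkpoint; any Qwen2.5-VL model would do.
llm = LLM(model="Qwen/Qwen2.5-VL-7B-Instruct")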

File tree: 3 files changed (+288, -4 lines)


vllm_ascend/envs.py

Lines changed: 5 additions & 0 deletions
@@ -128,6 +128,11 @@
     # enable `pin_memory` while creating a tensor using `torch.tensor`.
     "VLLM_ASCEND_ACL_OP_INIT_MODE":
     lambda: os.getenv("VLLM_ASCEND_ACL_OP_INIT_MODE", '1'),
+    # Some models are optimized by vllm-ascend. In some cases, e.g. RLHF
+    # training, the optimized model may not be suitable. In that case, set
+    # this value to False to disable the optimized model.
+    "USE_OPTIMIZED_MODEL":
+    lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
 }
 
 # end-env-vars-definition
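For context, a small sketch of how the new flag resolves at runtime. It assumes vllm_ascend/envs.py exposes each registered name as a module attribute whose value comes from the corresponding lambda, as the existing variables are used elsewhere in the project:

import os

os.environ["USE_OPTIMIZED_MODEL"] = "0"

import vllm_ascend.envs as envs  # noqa: E402

# bool(int("0")) -> False, so the unpadded registration path below is taken.
assert not envs.USE_OPTIMIZED_MODEL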

vllm_ascend/models/__init__.py

Lines changed: 10 additions & 4 deletions
@@ -20,10 +20,16 @@ def register_model():
         "Qwen2VLForConditionalGeneration",
         "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
 
-    ModelRegistry.register_model(
-        "Qwen2_5_VLForConditionalGeneration",
-        "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
-    )
+    if envs.USE_OPTIMIZED_MODEL:
+        ModelRegistry.register_model(
+            "Qwen2_5_VLForConditionalGeneration",
+            "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
+        )
+    else:
+        ModelRegistry.register_model(
+            "Qwen2_5_VLForConditionalGeneration",
+            "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
+        )
 
     if envs.VLLM_ASCEND_ENABLE_DBO:
         ModelRegistry.register_model(
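To confirm which implementation actually got registered (for example inside a verl worker), the architecture can be resolved through vLLM's registry. A hedged sketch; it assumes ModelRegistry.resolve_model_cls() is available in the vLLM version in use and that the vllm-ascend plugin has already run register_model():

from vllm import ModelRegistry

model_cls, arch = ModelRegistry.resolve_model_cls(
    ["Qwen2_5_VLForConditionalGeneration"])
# Expected: AscendQwen2_5_VLForConditionalGeneration with USE_OPTIMIZED_MODEL=1,
# AscendQwen2_5_VLForConditionalGeneration_Without_Padding with USE_OPTIMIZED_MODEL=0.
print(model_cls.__name__)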
vllm_ascend/models/qwen2_5_vl_without_padding.py

Lines changed: 273 additions & 0 deletions

@@ -0,0 +1,273 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
    Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
    Qwen2_5_VLProcessingInfo)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY


class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            embed_dim,
            num_heads,
            projection_size,
            quant_config,
            prefix,
        )
        self.embed_dim = embed_dim
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))
        q = torch_npu.npu_rotary_mul(q, cos, sin)
        k = torch_npu.npu_rotary_mul(k, cos, sin)

        q, k, v = [
            rearrange(x, "b s h d -> (b s) h d").contiguous()
            for x in (q, k, v)
        ]

        context_layer = torch.empty_like(q)

        # operator requires pta version >= 2.5.1.dev20250226
        torch_npu._npu_flash_attention_unpad(
            query=q,
            key=k,
            value=v,
            seq_len=cu_seqlens,
            scale_value=self.hidden_size_per_attention_head**-0.5,
            num_heads=self.num_attention_heads_per_partition,
            num_kv_heads=self.num_attention_heads_per_partition,
            out=context_layer)

        context_layer = rearrange(context_layer,
                                  "(b s) h d -> s b (h d)",
                                  b=batch_size).contiguous()

        output, _ = self.proj(context_layer)
        return output


class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock):

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
                         quant_config, prefix)
        self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn")

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(
            self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)

        x = x + self.mlp(self.norm2(x))
        return x


class AscendQwen2_5_VisionPatchEmbed_Without_Padding(Qwen2_5_VisionPatchEmbed):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.matmul(
            self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
        return x


class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
                                                      ):

    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        interleaved=False,
    ) -> None:
        super().__init__(vision_config, norm_eps, quant_config, prefix)
        norm_layer = partial(RMSNorm, eps=norm_eps)
        self.interleaved = interleaved
        self.patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding(
            patch_size=vision_config.patch_size,
            temporal_patch_size=vision_config.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )
        self.blocks = nn.ModuleList([
            AscendQwen2_5_VisionBlock_Without_Padding(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}")
            for layer_idx in range(vision_config.depth)
        ])
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            self.hidden_size, self.num_heads)

    def cal_cos_sin(self, rotary_pos_emb):
        cos = rotary_pos_emb.cos()  # [seqlen, rotary_dim / 2]
        sin = rotary_pos_emb.sin()

        if not self.interleaved:
            cos_new = torch.cat((cos, cos), dim=-1)
            sin_new = torch.cat((sin, sin), dim=-1)
        else:
            cos_new = rearrange(torch.stack((cos, cos), dim=-1),
                                "... d two -> ...(d two)",
                                two=2)
            sin_new = rearrange(torch.stack((sin, sin), dim=-1),
                                "... d two -> ...(d two)",
                                two=2)
        cos_new = cos_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        sin_new = sin_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        return cos_new, sin_new

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        # compute cu_seqlens
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                             grid_thw[:,
                                                      0]).cpu().to(torch.int32)

        # patchify
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # windows attention
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
        cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32)
        seq_len, _ = x.size()
        x = x.reshape(seq_len // self.spatial_merge_unit,
                      self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        cos, sin = self.cal_cos_sin(rotary_pos_emb)

        # transformers
        x = x.unsqueeze(1)
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin)

        # adapter
        x = self.merger(x)
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]
        return x


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
    info=Qwen2_5_VLProcessingInfo,
    dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
        Qwen2_5_VLForConditionalGeneration):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.visual = AscendQwen2_5_VisionTransformer_Without_Padding(
            vision_config=config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "visual"),
        )
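A note on the seq_len tensor fed to torch_npu._npu_flash_attention_unpad above: the forward pass builds it from grid_thw without a cumulative sum, so despite the cu_seqlens name it holds one length per frame (h * w patches, repeated t times). A small CPU-only sketch of that computation with made-up grid values:

import torch

# One image (1 temporal patch, 4 x 6 spatial patches) and one short video
# (2 temporal patches, 2 x 2 spatial patches) -- illustrative values only.
grid_thw = torch.tensor([[1, 4, 6],
                         [2, 2, 2]])

seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                   grid_thw[:, 0]).to(torch.int32)
# -> [24, 4, 4]: 24 patches for the image, 4 per video frame.
print(seq_lens)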
