#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Callable, Iterable, Optional, Set, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
    Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
    Qwen2_5_VLProcessingInfo)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY

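# Head sizes strictly between MIN_PAD_SIZE and MAX_PAD_SIZE are zero-padded up
# to MAX_PAD_SIZE (for Qwen2.5-VL's vision tower this pads 80-dim heads to
# 128), presumably to match the head-dim alignment expected by the fused NPU
# attention kernel.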
MIN_PAD_SIZE = 64
MAX_PAD_SIZE = 128


class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            embed_dim,
            num_heads,
            projection_size,
            quant_config,
            prefix,
        )
        self.embed_dim = embed_dim
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
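        # Pad the per-head size into the supported range; the corresponding
        # qkv/proj weights are padded to match in
        # AscendQwen2_5_VisionTransformer.load_weights.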
        if MIN_PAD_SIZE < self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))
        q = torch_npu.npu_rotary_mul(q, cos, sin)
        k = torch_npu.npu_rotary_mul(k, cos, sin)

        q, k, v = [
            rearrange(x, "b s h d -> (b s) h d").contiguous()
            for x in (q, k, v)
        ]

        context_layer = torch.empty_like(q)

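        # Fused unpadded attention on NPU. Note that `cu_seqlens` as computed
        # by the vision transformer below holds per-sequence token counts (not
        # cumulative offsets), which is what the kernel's `seq_len` argument
        # receives here.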
        # operator requires pta version >= 2.5.1.dev20250226
        torch_npu._npu_flash_attention_unpad(
            query=q,
            key=k,
            value=v,
            seq_len=cu_seqlens,
            scale_value=self.hidden_size_per_attention_head**-0.5,
            num_heads=self.num_attention_heads_per_partition,
            num_kv_heads=self.num_attention_heads_per_partition,
            out=context_layer)

        context_layer = rearrange(context_layer,
                                  "(b s) h d -> s b (h d)",
                                  b=batch_size).contiguous()

        output, _ = self.proj(context_layer)
        return output


class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock):

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
                         quant_config, prefix)
        self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim,
                                                  num_heads=num_heads,
                                                  projection_size=dim,
                                                  quant_config=quant_config,
                                                  prefix=f"{prefix}.attn")

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(
            self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)

        x = x + self.mlp(self.norm2(x))
        return x


class AscendQwen2_5_VisionPatchEmbed(Qwen2_5_VisionPatchEmbed):

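    # The patch projection is applied as a single matmul against the flattened
    # Conv3d weight; since each input row is one flattened patch, this matches
    # the upstream convolution-based projection.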
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.matmul(
            self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
        return x


class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):

    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        interleaved=False,
    ) -> None:
        super().__init__(vision_config, norm_eps, quant_config, prefix)
        norm_layer = partial(RMSNorm, eps=norm_eps)
        self.interleaved = interleaved
        self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
            patch_size=vision_config.patch_size,
            temporal_patch_size=vision_config.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )
        self.blocks = nn.ModuleList([
            AscendQwen2_5_VisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}")
            for layer_idx in range(vision_config.depth)
        ])
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            self.hidden_size, self.num_heads)

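        # Mirror the head-dim padding done in AscendQwen2_5_VisionAttention and
        # keep the original/half/pad sizes around for cal_cos_sin and the
        # weight-padding helpers below.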
        if MIN_PAD_SIZE < self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
            self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2
            self.half_pad_hidden_size_per_attention_head = (
                MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

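    # Precompute cos/sin tables zero-padded to the padded head size and shaped
    # (1, seqlen, 1, head_dim) so they broadcast against the (b, s, head,
    # head_dim) q/k tensors inside torch_npu.npu_rotary_mul.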
    def cal_cos_sin(self, rotary_pos_emb):
        cos = rotary_pos_emb.cos()  # [seqlen, rotary_dim / 2]
        sin = rotary_pos_emb.sin()
        cos = torch.nn.functional.pad(
            cos, (0, self.half_pad_hidden_size_per_attention_head))
        sin = torch.nn.functional.pad(
            sin, (0, self.half_pad_hidden_size_per_attention_head))

        if not self.interleaved:
            cos_new = torch.cat((cos, cos), dim=-1)
            sin_new = torch.cat((sin, sin), dim=-1)
        else:
            cos_new = rearrange(torch.stack((cos, cos), dim=-1),
                                "... d two -> ... (d two)",
                                two=2)
            sin_new = rearrange(torch.stack((sin, sin), dim=-1),
                                "... d two -> ... (d two)",
                                two=2)
        cos_new = cos_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        sin_new = sin_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        return cos_new, sin_new

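    # The three helpers below zero-pad each rotary half of the attention
    # parameters from the original head size to the padded head size, so the
    # loaded checkpoint lines up with the padded cos/sin tables;
    # pad_proj_weight pads the matching input columns of the output projection.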
    def pad_qkv_bias(self, bias):
        first_half = bias.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head
        )[:, :, :self.half_origin_hidden_size_per_attention_head]
        second_half = bias.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head
        )[:, :, self.half_origin_hidden_size_per_attention_head:]
        first_half_padded = torch.nn.functional.pad(
            first_half, (0, self.half_pad_hidden_size_per_attention_head))
        second_half_padded = torch.nn.functional.pad(
            second_half, (0, self.half_pad_hidden_size_per_attention_head))
        bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2)
        bias_final = bias_padded.reshape(-1)
        return bias_final

    def pad_qkv_weight(self, data):
        qkv_weight_first_half = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
        )[:, :, :self.half_origin_hidden_size_per_attention_head, :]
        qkv_weight_second_half = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
        )[:, :, self.half_origin_hidden_size_per_attention_head:, :]

        qkv_weight_first_half_padded = torch.nn.functional.pad(
            qkv_weight_first_half,
            (0, 0, 0, self.half_pad_hidden_size_per_attention_head))
        qkv_weight_second_half_padded = torch.nn.functional.pad(
            qkv_weight_second_half,
            (0, 0, 0, self.half_pad_hidden_size_per_attention_head))
        qkv_weight_padded = torch.cat(
            [qkv_weight_first_half_padded, qkv_weight_second_half_padded],
            dim=2)
        qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
        return qkv_weight_final

    def pad_proj_weight(self, data):
        out_weight = torch.nn.functional.pad(
            data.reshape(self.hidden_size, -1,
                         self.half_origin_hidden_size_per_attention_head),
            (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
                self.hidden_size, -1)
        return out_weight

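    # Standard vLLM stacked-parameter loading, plus zero-padding of the
    # attention qkv/proj parameters to the padded head size right after they
    # are loaded.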
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                if "attn.proj.weight" in name:
                    param.data = self.pad_proj_weight(param.data)
                if "attn.qkv.weight" in name:
                    param.data = self.pad_qkv_weight(param.data)
                if "attn.qkv.bias" in name:
                    param.data = self.pad_qkv_bias(param.data)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
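        # NOTE: the "cu_seqlens" built here are per-sequence token counts
        # (h * w repeated for every temporal frame) rather than cumulative
        # offsets, since that is what the fused NPU attention consumes.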
        # compute cu_seqlens
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                             grid_thw[:, 0]).cpu().to(
                                                 torch.int32)

        # patchify
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # windows attention
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
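        # Collapse duplicate boundaries and turn the cumulative window offsets
        # into per-window lengths for the fused attention kernel.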
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
        cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32)
        seq_len, _ = x.size()
        x = x.reshape(seq_len // self.spatial_merge_unit,
                      self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        cos, sin = self.cal_cos_sin(rotary_pos_emb)

        # transformers
        x = x.unsqueeze(1)
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin)

        # adapter
        x = self.merger(x)
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]
        return x


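# Drop-in variant of Qwen2_5_VLForConditionalGeneration that swaps the vision
# tower for the NPU-optimized AscendQwen2_5_VisionTransformer defined above.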
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
    info=Qwen2_5_VLProcessingInfo,
    dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class AscendQwen2_5_VLForConditionalGeneration(
        Qwen2_5_VLForConditionalGeneration):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.visual = AscendQwen2_5_VisionTransformer(
            vision_config=config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "visual"),
        )