#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Callable, Iterable, Optional, Set, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
    Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
    Qwen2_5_VLProcessingInfo)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY

MIN_PAD_SIZE = 64  # minimum head size that triggers weight padding
MAX_PAD_SIZE = 128  # head size that padded weights are expanded to
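# NOTE: head sizes that fall strictly between MIN_PAD_SIZE and MAX_PAD_SIZE
# (e.g. the 80-wide heads of the Qwen2.5-VL vision encoder) are zero-padded up
# to MAX_PAD_SIZE. The padding is presumably required by the Ascend
# flash-attention kernel used below (an assumption inferred from the code, not
# stated in this file); since the extra weight rows/columns are zeros, the
# attention result is unchanged.

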
class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            embed_dim,
            num_heads,
            projection_size,
            quant_config,
            prefix,
        )
        self.embed_dim = embed_dim
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.origin_hidden_size_per_attention_head = (
            self.hidden_size_per_attention_head)
        if MIN_PAD_SIZE < self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))
        q = torch_npu.npu_rotary_mul(q, cos, sin)
        k = torch_npu.npu_rotary_mul(k, cos, sin)

        q, k, v = [
            rearrange(x, "b s h d -> (b s) h d").contiguous()
            for x in (q, k, v)
        ]

        context_layer = torch.empty_like(q)

        # operator requires pta version >= 2.5.1
        # scale_value uses the original (unpadded) head size so the result
        # matches the unpadded model
        torch_npu._npu_flash_attention_unpad(
            query=q,
            key=k,
            value=v,
            seq_len=cu_seqlens,
            scale_value=self.origin_hidden_size_per_attention_head**-0.5,
            num_heads=self.num_attention_heads_per_partition,
            num_kv_heads=self.num_attention_heads_per_partition,
            out=context_layer)

        context_layer = rearrange(context_layer,
                                  "(b s) h d -> s b (h d)",
                                  b=batch_size).contiguous()

        output, _ = self.proj(context_layer)
        return output


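# The Ascend vision block forwards precomputed cos/sin tensors (built once per
# forward pass in AscendQwen2_5_VisionTransformer.cal_cos_sin) down to the
# attention layer, which applies them with torch_npu.npu_rotary_mul.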
class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock):

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
                         quant_config, prefix)
        self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim,
                                                  num_heads=num_heads,
                                                  projection_size=dim,
                                                  quant_config=quant_config,
                                                  prefix=f"{prefix}.attn")

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(
            self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)

        x = x + self.mlp(self.norm2(x))
        return x


class AscendQwen2_5_VisionPatchEmbed(Qwen2_5_VisionPatchEmbed):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The patches arrive already flattened, so the parent class's Conv3d
        # projection (whose kernel exactly covers one patch) can be applied as
        # a single matmul against its flattened weight.
        x = x.matmul(
            self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
        return x


class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):

    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        interleaved=False,
    ) -> None:
        super().__init__(vision_config, norm_eps, quant_config, prefix)
        norm_layer = partial(RMSNorm, eps=norm_eps)
        self.interleaved = interleaved
        self.enable_pad = False
        self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
            patch_size=vision_config.patch_size,
            temporal_patch_size=vision_config.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )
        self.blocks = nn.ModuleList([
            AscendQwen2_5_VisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}")
            for layer_idx in range(vision_config.depth)
        ])
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            self.hidden_size, self.num_heads)

        if MIN_PAD_SIZE < self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.enable_pad = True
            self.origin_hidden_size_per_attention_head = (
                self.hidden_size_per_attention_head)
            self.half_origin_hidden_size_per_attention_head = (
                self.hidden_size_per_attention_head // 2)
            self.half_pad_hidden_size_per_attention_head = (
                MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

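    # Returns cos/sin reshaped to [1, seqlen, 1, head_dim] so they broadcast
    # against the [b, s, h, d] query/key layout used by npu_rotary_mul; when
    # padding is enabled, each rotary half is zero-padded to the padded head
    # size.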
    def cal_cos_sin(self, rotary_pos_emb):
        cos = rotary_pos_emb.cos()  # [seqlen, rotary_dim / 2]
        sin = rotary_pos_emb.sin()
        if self.enable_pad:
            cos = torch.nn.functional.pad(
                cos, (0, self.half_pad_hidden_size_per_attention_head))
            sin = torch.nn.functional.pad(
                sin, (0, self.half_pad_hidden_size_per_attention_head))

        if not self.interleaved:
            cos_new = torch.cat((cos, cos), dim=-1)
            sin_new = torch.cat((sin, sin), dim=-1)
        else:
            cos_new = rearrange(torch.stack((cos, cos), dim=-1),
                                "... d two -> ...(d two)",
                                two=2)
            sin_new = rearrange(torch.stack((sin, sin), dim=-1),
                                "... d two -> ...(d two)",
                                two=2)
        cos_new = cos_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        sin_new = sin_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        return cos_new, sin_new

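    # The padding helpers below zero-pad each half of every head separately,
    # rather than appending all zeros at the end of the head, which matches
    # the per-half padding of cos/sin in cal_cos_sin and keeps the
    # rotate-half rotary layout aligned.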
    def pad_qkv_bias(self, bias):
        first_half = bias.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head
        )[:, :, :self.half_origin_hidden_size_per_attention_head]
        second_half = bias.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head
        )[:, :, self.half_origin_hidden_size_per_attention_head:]
        first_half_padded = torch.nn.functional.pad(
            first_half, (0, self.half_pad_hidden_size_per_attention_head))
        second_half_padded = torch.nn.functional.pad(
            second_half, (0, self.half_pad_hidden_size_per_attention_head))
        bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2)
        bias_final = bias_padded.reshape(-1)
        return bias_final

    def pad_qkv_weight(self, data):
        qkv_weight_first_half = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head,
            self.hidden_size
        )[:, :, :self.half_origin_hidden_size_per_attention_head, :]
        qkv_weight_second_half = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head,
            self.hidden_size
        )[:, :, self.half_origin_hidden_size_per_attention_head:, :]

        qkv_weight_first_half_padded = torch.nn.functional.pad(
            qkv_weight_first_half,
            (0, 0, 0, self.half_pad_hidden_size_per_attention_head))
        qkv_weight_second_half_padded = torch.nn.functional.pad(
            qkv_weight_second_half,
            (0, 0, 0, self.half_pad_hidden_size_per_attention_head))
        qkv_weight_padded = torch.cat(
            [qkv_weight_first_half_padded, qkv_weight_second_half_padded],
            dim=2)
        qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
        return qkv_weight_final

    def pad_proj_weight(self, data):
        out_weight = torch.nn.functional.pad(
            data.reshape(self.hidden_size, -1,
                         self.half_origin_hidden_size_per_attention_head),
            (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
                self.hidden_size, -1)
        return out_weight

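    # Standard vLLM weight loading, with one addition: once a parameter has
    # been loaded, the attention qkv/proj weights (and qkv bias) are re-padded
    # in place when enable_pad is set, so the stored tensors already use the
    # padded head layout expected at runtime.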
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                if ("attn.proj.weight" in name) and self.enable_pad:
                    param.data = self.pad_proj_weight(param.data)
                if ("attn.qkv.weight" in name) and self.enable_pad:
                    param.data = self.pad_qkv_weight(param.data)
                if ("attn.qkv.bias" in name) and self.enable_pad:
                    param.data = self.pad_qkv_bias(param.data)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        # compute per-frame sequence lengths; despite the variable name these
        # are lengths, not cumulative offsets, which is what the unpad
        # attention kernel consumes as seq_len
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2],
            grid_thw[:, 0]).cpu().to(torch.int32)

        # patchify
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # window attention
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
        # convert cumulative window offsets into per-window lengths
        cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32)
        seq_len, _ = x.size()
        x = x.reshape(seq_len // self.spatial_merge_unit,
                      self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        cos, sin = self.cal_cos_sin(rotary_pos_emb)

        # transformer blocks
        x = x.unsqueeze(1)
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin)

        # adapter
        x = self.merger(x)
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]
        return x


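# Registration reuses the upstream processing stack (processor, processing
# info and dummy-inputs builder); only the vision tower is replaced with the
# Ascend-specific implementation in __init__ below.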
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
    info=Qwen2_5_VLProcessingInfo,
    dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class AscendQwen2_5_VLForConditionalGeneration(
        Qwen2_5_VLForConditionalGeneration):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.visual = AscendQwen2_5_VisionTransformer(
            vision_config=config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "visual"),
        )