|
| 1 | +# |
| 2 | +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. |
| 3 | +# Adapted from vllm/model_executor/models/qwen2_5_vl.py |
| 4 | +# Copyright 2023 The vLLM team. |
| 5 | +# |
| 6 | +# This file is a part of the vllm-ascend project. |
| 7 | +# |
| 8 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 9 | +# you may not use this file except in compliance with the License. |
| 10 | +# You may obtain a copy of the License at |
| 11 | +# |
| 12 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 13 | +# |
| 14 | +# Unless required by applicable law or agreed to in writing, software |
| 15 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 16 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 17 | +# See the License for the specific language governing permissions and |
| 18 | +# limitations under the License. |
| 19 | + |
| 20 | +from functools import partial |
| 21 | +from typing import Callable, Optional |
| 22 | + |
| 23 | +import torch |
| 24 | +import torch.nn as nn |
| 25 | +import torch.nn.functional as F |
| 26 | +import torch_npu |
| 27 | +from einops import rearrange |
| 28 | +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( |
| 29 | + Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) |
| 30 | +from vllm.config import VllmConfig |
| 31 | +from vllm.distributed import parallel_state |
| 32 | +from vllm.distributed import utils as dist_utils |
| 33 | +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY |
| 34 | +from vllm.model_executor.layers.layernorm import RMSNorm |
| 35 | +from vllm.model_executor.layers.quantization import QuantizationConfig |
| 36 | +from vllm.model_executor.models.qwen2_5_vl import ( |
| 37 | + Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed, |
| 38 | + Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder, |
| 39 | + Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor, |
| 40 | + Qwen2_5_VLProcessingInfo) |
| 41 | +from vllm.model_executor.models.utils import maybe_prefix |
| 42 | +from vllm.multimodal import MULTIMODAL_REGISTRY |
| 43 | + |
| 44 | + |
| 45 | +class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention): |
| 46 | + |
| 47 | + def __init__( |
| 48 | + self, |
| 49 | + embed_dim: int, |
| 50 | + num_heads: int, |
| 51 | + projection_size: int, |
| 52 | + quant_config: Optional[QuantizationConfig] = None, |
| 53 | + prefix: str = "", |
| 54 | + ) -> None: |
| 55 | + super().__init__( |
| 56 | + embed_dim, |
| 57 | + num_heads, |
| 58 | + projection_size, |
| 59 | + quant_config, |
| 60 | + prefix, |
| 61 | + ) |
| 62 | + self.embed_dim = embed_dim |
| 63 | + self.hidden_size_per_attention_head = dist_utils.divide( |
| 64 | + projection_size, num_heads) |
| 65 | + |
| 66 | + def forward( |
| 67 | + self, |
| 68 | + x: torch.Tensor, |
| 69 | + cu_seqlens: torch.Tensor, |
| 70 | + cos: torch.Tensor, |
| 71 | + sin: torch.Tensor, |
| 72 | + ) -> torch.Tensor: |
| 73 | + # [s, b, c] --> [s, b, head * 3 * head_dim] |
| 74 | + x, _ = self.qkv(x) |
| 75 | + |
| 76 | + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] |
| 77 | + q, k, v = self.split_qkv(x) |
| 78 | + batch_size = q.shape[1] |
| 79 | + |
| 80 | + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() |
| 81 | + for x in (q, k, v)) |
| 82 | + q = torch_npu.npu_rotary_mul(q, cos, sin) |
| 83 | + k = torch_npu.npu_rotary_mul(k, cos, sin) |
| 84 | + |
| 85 | + q, k, v = [ |
| 86 | + rearrange(x, "b s h d -> (b s) h d").contiguous() |
| 87 | + for x in (q, k, v) |
| 88 | + ] |
| 89 | + |
| 90 | + context_layer = torch.torch.empty_like(q) |
| 91 | + |
| 92 | + # operator requires pta version >= 2.5.1.dev20250226 |
| 93 | + torch_npu._npu_flash_attention_unpad( |
| 94 | + query=q, |
| 95 | + key=k, |
| 96 | + value=v, |
| 97 | + seq_len=cu_seqlens, |
| 98 | + scale_value=self.hidden_size_per_attention_head**-0.5, |
| 99 | + num_heads=self.num_attention_heads_per_partition, |
| 100 | + num_kv_heads=self.num_attention_heads_per_partition, |
| 101 | + out=context_layer) |
| 102 | + |
| 103 | + context_layer = rearrange(context_layer, |
| 104 | + "(b s) h d -> s b (h d)", |
| 105 | + b=batch_size).contiguous() |
| 106 | + |
| 107 | + output, _ = self.proj(context_layer) |
| 108 | + return output |
| 109 | + |
| 110 | + |
| 111 | +class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock): |
| 112 | + |
| 113 | + def __init__( |
| 114 | + self, |
| 115 | + dim: int, |
| 116 | + num_heads: int, |
| 117 | + mlp_hidden_dim: int, |
| 118 | + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, |
| 119 | + norm_layer: Optional[Callable[[int], nn.Module]] = None, |
| 120 | + quant_config: Optional[QuantizationConfig] = None, |
| 121 | + prefix: str = "", |
| 122 | + ) -> None: |
| 123 | + super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, |
| 124 | + quant_config, prefix) |
| 125 | + self.attn = AscendQwen2_5_VisionAttention_Without_Padding( |
| 126 | + embed_dim=dim, |
| 127 | + num_heads=num_heads, |
| 128 | + projection_size=dim, |
| 129 | + quant_config=quant_config, |
| 130 | + prefix=f"{prefix}.attn") |
| 131 | + |
| 132 | + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, |
| 133 | + cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: |
| 134 | + x = x + self.attn( |
| 135 | + self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin) |
| 136 | + |
| 137 | + x = x + self.mlp(self.norm2(x)) |
| 138 | + return x |
| 139 | + |
| 140 | + |
| 141 | +class AscendQwen2_5_VisionPatchEmbed_Without_Padding(Qwen2_5_VisionPatchEmbed): |
| 142 | + |
| 143 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 144 | + x = x.matmul( |
| 145 | + self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1)) |
| 146 | + return x |
| 147 | + |
| 148 | + |
| 149 | +class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer |
| 150 | + ): |
| 151 | + |
| 152 | + def __init__( |
| 153 | + self, |
| 154 | + vision_config: Qwen2_5_VLVisionConfig, |
| 155 | + norm_eps: float = 1e-6, |
| 156 | + quant_config: Optional[QuantizationConfig] = None, |
| 157 | + prefix: str = "", |
| 158 | + interleaved=False, |
| 159 | + ) -> None: |
| 160 | + super().__init__(vision_config, norm_eps, quant_config, prefix) |
| 161 | + norm_layer = partial(RMSNorm, eps=norm_eps) |
| 162 | + self.interleaved = interleaved |
| 163 | + self.patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding( |
| 164 | + patch_size=vision_config.patch_size, |
| 165 | + temporal_patch_size=vision_config.temporal_patch_size, |
| 166 | + in_channels=vision_config.in_channels, |
| 167 | + hidden_size=self.hidden_size, |
| 168 | + ) |
| 169 | + self.blocks = nn.ModuleList([ |
| 170 | + AscendQwen2_5_VisionBlock_Without_Padding( |
| 171 | + dim=self.hidden_size, |
| 172 | + num_heads=self.num_heads, |
| 173 | + mlp_hidden_dim=vision_config.intermediate_size, |
| 174 | + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], |
| 175 | + norm_layer=norm_layer, |
| 176 | + quant_config=quant_config, |
| 177 | + prefix=f"{prefix}.blocks.{layer_idx}") |
| 178 | + for layer_idx in range(vision_config.depth) |
| 179 | + ]) |
| 180 | + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() |
| 181 | + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() |
| 182 | + self.hidden_size_per_attention_head = dist_utils.divide( |
| 183 | + self.hidden_size, self.num_heads) |
| 184 | + |
| 185 | + def cal_cos_sin(self, rotary_pos_emb): |
| 186 | + cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] |
| 187 | + sin = rotary_pos_emb.sin() |
| 188 | + |
| 189 | + if not self.interleaved: |
| 190 | + cos_new = torch.cat((cos, cos), dim=-1) |
| 191 | + sin_new = torch.cat((sin, sin), dim=-1) |
| 192 | + else: |
| 193 | + cos_new = rearrange(torch.stack((cos, cos), dim=-1), |
| 194 | + "... d two -> ...(d two)", |
| 195 | + two=2) |
| 196 | + sin_new = rearrange(torch.stack((sin, sin), dim=-1), |
| 197 | + "... d two -> ...(d two)", |
| 198 | + two=2) |
| 199 | + cos_new = cos_new.reshape(1, -1, 1, |
| 200 | + self.hidden_size_per_attention_head) |
| 201 | + sin_new = sin_new.reshape(1, -1, 1, |
| 202 | + self.hidden_size_per_attention_head) |
| 203 | + return cos_new, sin_new |
| 204 | + |
| 205 | + def forward( |
| 206 | + self, |
| 207 | + x: torch.Tensor, |
| 208 | + grid_thw: torch.Tensor, |
| 209 | + ) -> torch.Tensor: |
| 210 | + # compute cu_seqlens |
| 211 | + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], |
| 212 | + grid_thw[:, |
| 213 | + 0]).cpu().to(torch.int32) |
| 214 | + |
| 215 | + # patchify |
| 216 | + x = self.patch_embed(x) |
| 217 | + |
| 218 | + # compute position embedding |
| 219 | + rotary_pos_emb = self.rot_pos_emb(grid_thw) |
| 220 | + |
| 221 | + # windows attention |
| 222 | + window_index, cu_window_seqlens = self.get_window_index(grid_thw) |
| 223 | + cu_window_seqlens = torch.tensor( |
| 224 | + cu_window_seqlens, |
| 225 | + device=x.device, |
| 226 | + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32) |
| 227 | + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) |
| 228 | + cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32) |
| 229 | + seq_len, _ = x.size() |
| 230 | + x = x.reshape(seq_len // self.spatial_merge_unit, |
| 231 | + self.spatial_merge_unit, -1) |
| 232 | + x = x[window_index, :, :] |
| 233 | + x = x.reshape(seq_len, -1) |
| 234 | + rotary_pos_emb = rotary_pos_emb.reshape( |
| 235 | + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) |
| 236 | + rotary_pos_emb = rotary_pos_emb[window_index, :, :] |
| 237 | + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) |
| 238 | + |
| 239 | + cos, sin = self.cal_cos_sin(rotary_pos_emb) |
| 240 | + |
| 241 | + # transformers |
| 242 | + x = x.unsqueeze(1) |
| 243 | + for layer_num, blk in enumerate(self.blocks): |
| 244 | + if layer_num in self.fullatt_block_indexes: |
| 245 | + cu_seqlens_now = cu_seqlens |
| 246 | + else: |
| 247 | + cu_seqlens_now = cu_window_seqlens |
| 248 | + x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin) |
| 249 | + |
| 250 | + # adapter |
| 251 | + x = self.merger(x) |
| 252 | + reverse_indices = torch.argsort(window_index) |
| 253 | + x = x[reverse_indices, :] |
| 254 | + return x |
| 255 | + |
| 256 | + |
| 257 | +@MULTIMODAL_REGISTRY.register_processor( |
| 258 | + Qwen2_5_VLMultiModalProcessor, |
| 259 | + info=Qwen2_5_VLProcessingInfo, |
| 260 | + dummy_inputs=Qwen2_5_VLDummyInputsBuilder) |
| 261 | +class AscendQwen2_5_VLForConditionalGeneration_Without_Padding( |
| 262 | + Qwen2_5_VLForConditionalGeneration): |
| 263 | + |
| 264 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 265 | + super().__init__(vllm_config=vllm_config, prefix=prefix) |
| 266 | + config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config |
| 267 | + quant_config = vllm_config.quant_config |
| 268 | + self.visual = AscendQwen2_5_VisionTransformer_Without_Padding( |
| 269 | + vision_config=config.vision_config, |
| 270 | + norm_eps=getattr(config, "rms_norm_eps", 1e-6), |
| 271 | + quant_config=self._maybe_ignore_quant_config(quant_config), |
| 272 | + prefix=maybe_prefix(prefix, "visual"), |
| 273 | + ) |
0 commit comments