@@ -29,7 +29,7 @@
 
 import torch
 import torch.distributed as dist
-import torch_npu
+import torch_npu  # noqa: F401
 import vllm.envs as envs
 from torch import nn
 from transformers import PretrainedConfig
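A note on the surviving line in this hunk: after the refactor below, torch_npu is no longer referenced by name in this module, but the import must stay because it has side effects (it registers the NPU device backend and the torch.npu / torch_npu operator namespaces with PyTorch). The `# noqa: F401` marker records that the otherwise "unused" import is intentional:

# torch_npu is imported for its side effects only: the import registers the
# NPU backend and custom operators with PyTorch. flake8 would otherwise
# report F401 (imported but unused).
import torch_npu  # noqa: F401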
@@ -40,13 +40,10 @@
                              get_tp_group, tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
                                                ReplicatedLinear,
-                                               RowParallelLinear,
-                                               UnquantizedLinearMethod)
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -67,6 +64,7 @@
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,
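The new import above is the heart of the change: the DBO model now reuses the DeepseekV2 MLP instead of carrying its own copy, which is why SiluAndMul, MergedColumnParallelLinear, UnquantizedLinearMethod, and AscendW8A8DynamicLinearMethod disappear from this file. A minimal outline of the resulting shape, assuming CustomDeepseekV2MLP in vllm_ascend/models/deepseek_v2.py provides the same __init__ and forward() (including the W8A8 dynamic-quant fast path) that the deleted code below used to define:

# Hypothetical outline, not the actual file contents: gate_up_proj,
# down_proj, act_fn, is_dynamic_quant and forward() are all inherited.
class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):

    def _forward_ms_mlp(self, x):
        # DBO-specific: same math as forward(), but the tensor-parallel
        # all-reduce is issued on a dedicated communication stream
        # (see the sketch after the diff).
        ...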
@@ -78,117 +76,17 @@
     make_multistream_metadata_ds)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor
 
 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
 
-class CustomDeepseekDBOMLP(nn.Module):
-
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        hidden_act: str,
-        quant_config: Optional[QuantizationConfig] = None,
-        reduce_results: bool = True,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
-
-        # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
-        self.is_dynamic_quant = not isinstance(
-            self.gate_up_proj.quant_method,
-            UnquantizedLinearMethod) and isinstance(
-                self.gate_up_proj.quant_method.quant_method,
-                AscendW8A8DynamicLinearMethod)
-
-    def forward(self, x):
-        if self.is_dynamic_quant:
-            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.gate_up_proj.weight,
-                self.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=x,
-                weight_scale=self.gate_up_proj.weight_scale_fp32,
-                activation_scale=dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.down_proj.weight,
-                self.down_proj.weight_scale,
-                pertoken_scale=dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-                x = tensor_model_parallel_all_reduce(x)
-            return x
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
+class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
 
     def _forward_ms_mlp(self, x):
         current_ms_metadata = get_multistream_comm_context()
         assert current_ms_metadata is not None
-        if self.is_dynamic_quant:
-            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.gate_up_proj.weight,
-                self.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=x,
-                weight_scale=self.gate_up_proj.weight_scale_fp32,
-                activation_scale=dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.down_proj.weight,
-                self.down_proj.weight_scale,
-                pertoken_scale=dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-                current_ms_metadata.before_comm_event.record()
-                with torch.npu.stream(current_ms_metadata.comm_stream):
-                    current_ms_metadata.before_comm_event.wait()
-                    x = tensor_model_parallel_all_reduce(x)
-                    current_ms_metadata.after_comm_event.record()
-            return x
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         current_ms_metadata.before_comm_event.record()
|
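What `_forward_ms_mlp` keeps after the cleanup is the stream/event choreography that makes DBO (dual-batch overlap) work: the MLP's tensor-parallel all-reduce is pushed onto a separate communication stream so the other micro-batch's compute can run underneath it. Below is a self-contained sketch of that pattern, with a hypothetical `meta` object standing in for the multistream metadata (`before_comm_event`, `after_comm_event`, `comm_stream`) used above; torch.npu mirrors the torch.cuda stream/event API once torch_npu is imported.

import torch
import torch_npu  # noqa: F401  # registers the torch.npu stream/event API

def overlapped_all_reduce(x, meta, all_reduce):
    # Fence on the default stream: everything that produced x is enqueued.
    meta.before_comm_event.record()
    with torch.npu.stream(meta.comm_stream):
        # The comm stream waits only for that fence, not for later compute,
        # so kernels launched on the default stream after this call can
        # overlap with the all-reduce.
        meta.before_comm_event.wait()
        x = all_reduce(x)  # e.g. tensor_model_parallel_all_reduce
        # Consumers of x synchronize on this event before reading it.
        meta.after_comm_event.record()
    return x

Correctness hinges on the second fence: nothing on the default stream may read `x` until `after_comm_event` has been recorded and waited on, which is what the surrounding MSEventKey/multistream metadata machinery enforces.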
|