 
 import numpy as np
 import torch
-import torch.distributed as dist
 import torch_npu
 from vllm.config import get_current_vllm_config
-from vllm.distributed import GroupCoordinator, get_ep_group
+from vllm.distributed import get_ep_group
+from vllm.forward_context import get_forward_context

 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.ops.fused_moe import select_experts
-from vllm_ascend.utils import dispose_tensor
-
-
-def apply_mlp(hidden_states: torch.Tensor,
-              w1: torch.Tensor,
-              w1_scale: torch.Tensor,
-              w2: torch.Tensor,
-              w2_scale: torch.Tensor,
-              w1_scale_bias: torch.Tensor,
-              w2_scale_bias: torch.Tensor,
-              group_list: torch.Tensor,
-              dynamic_scale: torch.Tensor = None,
-              group_list_type: int = 1) -> torch.Tensor:
-    """
-    apply MLP: gate_up_proj -> swiglu -> down_proj
-
-    Args:
-        hidden_states: input hidden states with shape (num_tokens, hidden_size).
-        w1: expert weights1 with shape
-            (num_experts, hidden_size, intermediate_size * 2)
-        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
-        w2: expert weights2 with shape
-            (num_experts, intermediate_size, hidden_size)
-        w2_scale: weights2 scale with shape (num_experts, hidden_size)
-        group_list: number of tokens for each expert, follow cumsum mode, and
-            with shape (num_experts).
-        transpose_weight:
-            w1: (num_experts, intermediate_size * 2, hidden_size) ->
-                (num_experts, hidden_size, intermediate_size * 2)
-            w2: (num_experts, hidden_size, intermediate_size) ->
-                (num_experts, intermediate_size, hidden_size)
-
-    Returns:
-        hidden_states: output hidden states after MLP.
-    """
-
-    if dynamic_scale is None:
-        unquantized_hidden_states = hidden_states
-        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
-            hidden_states)
-        # Dispose the original unquantized hidden states
-        # to save npu memory because they're no longer used.
-        dispose_tensor(unquantized_hidden_states)
-    else:
-        pertoken_scale = dynamic_scale
-
-    # gmm1: gate_up_proj
-    group_list_type = 1
-    group_list = torch.cat([group_list[:1], torch.diff(group_list, dim=0)])
-
-    hidden_states = torch_npu.npu_grouped_matmul(
-        x=[hidden_states],
-        weight=[w1],
-        scale=[w1_scale],
-        bias=[w1_scale_bias],
-        per_token_scale=[pertoken_scale],
-        split_item=2,
-        group_list_type=group_list_type,
-        group_type=0,
-        group_list=group_list,
-        output_dtype=torch.bfloat16)[0]
-
-    # act_fn: swiglu
-    hidden_states = torch_npu.npu_swiglu(hidden_states)
-    hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
-        hidden_states)
-
-    # gmm2: down_proj
-    hidden_states = torch_npu.npu_grouped_matmul(
-        x=[hidden_states],
-        weight=[w2],
-        scale=[w2_scale],
-        bias=[w2_scale_bias],
-        per_token_scale=[swiglu_out_scale],
-        split_item=2,
-        group_list_type=group_list_type,
-        group_type=0,
-        group_list=group_list,
-        output_dtype=torch.bfloat16)[0]
-    return hidden_states
-
-
-# currently expert parallelism implemented with all2all
-# is under-optimized.
-def fused_experts_with_all2all(
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w1_scale: torch.Tensor,
-        w2: torch.Tensor,
-        w2_scale: torch.Tensor,
-        w1_scale_bias: torch.Tensor,
-        w2_scale_bias: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        top_k: int,
-        expert_map: torch.Tensor = None,
-        ep_group: GroupCoordinator = None,
-        log2phy: torch.Tensor = None,
-        global_redundant_expert_num: int = 0,
-):
-    if log2phy:
-        topk_ids = log2phy[topk_ids]
-    original_shape = hidden_states.shape
-    if len(original_shape) == 3:
-        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-
-    num_tokens, _ = hidden_states.shape
-    num_experts = w1.shape[0]
-    device = hidden_states.device
-
-    if expert_map is not None:
-        global_num_experts = len(expert_map) + global_redundant_expert_num
-        local_num_experts = global_num_experts // ep_group.world_size
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
-        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-            hidden_states,
-            row_idx=row_idx,
-            expert_idx=topk_ids,
-            active_num=num_tokens)
-
-        global_expert_tokens = torch.bincount(expanded_expert_idx,
-                                              minlength=global_num_experts)
-        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-                                                  -1).sum(-1)
-
-        gather_sizes = torch.empty_like(scatter_sizes)
-        dist.all_to_all_single(gather_sizes,
-                               scatter_sizes,
-                               group=ep_group.device_group)
-        scatter_size_list = scatter_sizes.cpu().tolist()
-        gather_size_list = gather_sizes.cpu().tolist()
-
-        expanded_expert_idx = expanded_expert_idx % local_num_experts
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            scatter_size_list,
-                                            gather_size_list)
-        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
-                                               scatter_size_list,
-                                               gather_size_list)
-
-        sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
-
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            sorted_local_expert_idx, local_num_experts).to(torch.int64)
-
-        hidden_states = hidden_states[sorted_idx]
-        group_list_type = 0
-    else:
-        row_idx_len = num_tokens * top_k
-        row_idx = torch.arange(0,
-                               row_idx_len,
-                               dtype=torch.int32,
-                               device=topk_weights.device).view(
-                                   top_k, -1).permute(1, 0).contiguous()
-        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-            hidden_states,
-            row_idx=row_idx,
-            expert_idx=topk_ids,
-            active_num=num_tokens)
-
-        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-            expanded_expert_idx, num_experts)
-        expert_tokens = expert_tokens.to(torch.int64)
-        group_list_type = 0
-
-    # `hidden_states` will be disposed in the `apply_mlp` function
-    hidden_states = apply_mlp(hidden_states,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              w1_scale_bias,
-                              w2_scale_bias,
-                              expert_tokens,
-                              group_list_type=group_list_type)
-
-    if expert_map is not None:
-        resorted_idx = torch.argsort(sorted_idx)
-        hidden_states = hidden_states[resorted_idx]
-        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
-                                            gather_size_list,
-                                            scatter_size_list)
-
-        final_hidden_states = torch_npu.npu_moe_finalize_routing(
-            hidden_states,
-            skip1=None,
-            skip2=None,
-            bias=None,
-            scales=topk_weights,
-            expanded_src_to_dst_row=expanded_row_idx,
-            export_for_source_row=topk_ids,
-        )
-    else:
-        # TODO: Reorder device memory 2 times here, replace the current
-        # implementation here when suitable operators become available.
-        final_hidden_states = torch_npu.npu_moe_finalize_routing(
-            hidden_states,
-            skip1=None,
-            skip2=None,
-            bias=None,
-            scales=topk_weights,
-            expanded_src_to_dst_row=expanded_row_idx,
-            export_for_source_row=topk_ids,
-        )
-    if len(original_shape) == 3:
-        final_hidden_states = final_hidden_states.view(original_shape)
-    return final_hidden_states
+from vllm_ascend.quantization.w8a8_dynamic import (fused_experts_with_all2all,
+                                                   fused_experts_with_mc2)


 class AscendW4A8DynamicLinearMethod:
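
Both helpers deleted in this hunk are now imported from vllm_ascend.quantization.w8a8_dynamic, so the w4a8 path shares one implementation with w8a8. For readers following the deleted code: apply_mlp runs each expert's MLP as two NPU grouped matmuls around a SwiGLU activation, on tokens already laid out contiguously per expert, and converts the cumsum-mode group_list into per-expert counts before the first matmul. Below is a minimal pure-PyTorch sketch of that dataflow only; it drops the int8 dynamic quantization and assumes a particular gate/up split order for the activation, so it is illustrative rather than a drop-in for the torch_npu kernels.

import torch

def reference_expert_mlp(hidden_states: torch.Tensor, w1: torch.Tensor,
                         w2: torch.Tensor, counts: torch.Tensor) -> torch.Tensor:
    # Plain-PyTorch reference of the removed apply_mlp dataflow:
    # gate_up_proj -> SwiGLU -> down_proj, one contiguous token block per
    # expert. `counts` is the count-mode group_list, i.e. what
    # torch.cat([group_list[:1], torch.diff(group_list, dim=0)]) produces
    # from the cumsum-mode input. Quantization is omitted, and the gate/up
    # split order is an assumption about npu_swiglu's layout.
    outputs, start = [], 0
    for e, n in enumerate(counts.tolist()):
        x = hidden_states[start:start + n]           # tokens routed to expert e
        gate, up = (x @ w1[e]).chunk(2, dim=-1)      # gate_up_proj
        x = torch.nn.functional.silu(gate) * up      # SwiGLU
        outputs.append(x @ w2[e])                    # down_proj
        start += n
    return torch.cat(outputs, dim=0)

# Toy shapes: 3 experts, hidden_size=8, intermediate_size=16.
counts = torch.tensor([3, 2, 4])
x = torch.randn(int(counts.sum()), 8)
w1 = torch.randn(3, 8, 32)    # (num_experts, hidden_size, intermediate_size * 2)
w2 = torch.randn(3, 16, 8)    # (num_experts, intermediate_size, hidden_size)
print(reference_expert_mlp(x, w1, w2, counts).shape)   # torch.Size([9, 8])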
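The deleted fused_experts_with_all2all sizes its cross-rank exchange before moving any hidden states: it counts routed token copies per global expert, folds those counts into per-EP-rank totals (experts are sharded contiguously across ranks), then exchanges the totals with dist.all_to_all_single so every rank knows its receive sizes for the subsequent token all_to_all. A toy single-process illustration of that sizing step, with made-up numbers:

import torch

# Toy illustration (no process group): how per-expert counts become the
# per-rank scatter sizes used by the removed all2all path.
world_size, global_num_experts = 2, 4          # experts 0-1 on rank 0, 2-3 on rank 1
expanded_expert_idx = torch.tensor([0, 0, 1, 3, 3, 3, 2])  # routed copies on this rank
global_expert_tokens = torch.bincount(expanded_expert_idx,
                                      minlength=global_num_experts)  # tensor([2, 1, 1, 3])
scatter_sizes = global_expert_tokens.view(world_size, -1).sum(-1)    # tensor([3, 4])
# In the real code, dist.all_to_all_single exchanges these totals across the
# EP group to obtain the matching gather sizes before the token all_to_all.
print(scatter_sizes.tolist())   # [3, 4] tokens destined for rank 0 and rank 1
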
@@ -483,26 +274,46 @@ def apply(
 
         topk_weights = topk_weights.to(x.dtype)
 
-        # The current implementation of deepseek moe splits hidden_states
-        # according to tp_size before they are feed into fused_moe module.
-        # Therefore, all2all is needed no matter how dp/tp is set so as to
-        # dispatch/combine tokens.
-        return fused_experts_with_all2all(
-            hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
-            w1_scale=layer.w13_weight_scale_second,
-            w2_scale=layer.w2_weight_scale_second,
-            w1_scale_bias=layer.w13_scale_bias,
-            w2_scale_bias=layer.w2_scale_bias,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            top_k=top_k,
-            expert_map=expert_map,
-            ep_group=self.ep_group,
-            log2phy=log2phy,
-            global_redundant_expert_num=global_redundant_expert_num,
-        )
+        fused_moe_state = get_forward_context().fused_moe_state
+        if fused_moe_state == FusedMoEState.MC2:
+            return fused_experts_with_mc2(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                w1_scale=layer.w13_weight_scale_second,
+                w2_scale=layer.w2_weight_scale_second,
+                w1_scale_bias=layer.w13_scale_bias,
+                w2_scale_bias=layer.w2_scale_bias,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num,
+                shared_experts=shared_experts,
+                is_torchair=self.torchair_graph_enabled)
+        else:
+            # The current implementation of deepseek moe splits hidden_states
+            # according to tp_size before they are feed into fused_moe module.
+            # Therefore, all2all is needed no matter how dp/tp is set so as to
+            # dispatch/combine tokens.
+            return fused_experts_with_all2all(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                w1_scale=layer.w13_weight_scale_second,
+                w2_scale=layer.w2_weight_scale_second,
+                w1_scale_bias=layer.w13_scale_bias,
+                w2_scale_bias=layer.w2_scale_bias,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                ep_group=self.ep_group,
+                log2phy=log2phy,
+                global_redundant_expert_num=global_redundant_expert_num,
+            )
 
     def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
         group_num, k, n = weight.shape
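
With this hunk, apply() no longer hard-codes the all2all path: it reads the per-step fused_moe_state from the forward context and calls fused_experts_with_mc2 when MC2 is active, otherwise it falls back to fused_experts_with_all2all; only the MC2 call adds moe_all_to_all_group_name, shared_experts, and is_torchair on top of the shared arguments. A minimal stand-alone sketch of that dispatch shape, using stand-in names rather than the vllm-ascend objects:

from enum import Enum, auto
from typing import Callable

# Stand-ins for vllm_ascend.ascend_forward_context.FusedMoEState and the two
# kernels imported from w8a8_dynamic; names and members here are illustrative.
class MoEState(Enum):
    ALL2ALL = auto()
    MC2 = auto()

def dispatch_fused_experts(state: MoEState,
                           mc2_kernel: Callable[..., str],
                           all2all_kernel: Callable[..., str],
                           **common_kwargs) -> str:
    # Mirror of the branch added above: MC2 when the forward context says so,
    # otherwise the all2all fallback with the shared keyword arguments.
    if state == MoEState.MC2:
        return mc2_kernel(**common_kwargs)
    return all2all_kernel(**common_kwargs)

print(dispatch_fused_experts(MoEState.MC2,
                             lambda **kw: "mc2 path",
                             lambda **kw: "all2all path",
                             top_k=8))   # -> mc2 path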