|
15 | 15 | # This file is a part of the vllm-ascend project.
|
16 | 16 | # Adapted from vllm/tests/kernels/test_moe.py
|
17 | 17 |
|
18 | | -from typing import Callable, Optional
| 18 | +from typing import Callable, List, Optional |
19 | 19 |
|
20 | 20 | import torch
|
21 | 21 | import torch.distributed as dist
|
|
37 | 37 |
|
38 | 38 | VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
39 | 39 | USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
|
| 40 | +MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER |
| 41 | + |
| 42 | + |
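The new MOE_ALL2ALL_BUFFER flag is resolved from vllm_ascend.envs at module import time, just like VLLM_ENABLE_MC2 and USING_LCCL_COM above. A minimal opt-in sketch, assuming the knob is registered in envs.py under the same environment-variable name (envs.py is not part of this hunk, so the exact name is an assumption):

```python
import os

# Assumption: the environment variable shares the Python flag's name.
# Set it before importing vllm / vllm-ascend, since the flag is read at import time.
os.environ["MOE_ALL2ALL_BUFFER"] = "1"

# The MoE apply() dispatch further down in this diff then selects
# fused_experts_with_all2all_buffer instead of fused_experts_with_all2all.
```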
| 43 | +def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, |
| 44 | + max_row_per_ep_rank: int, num_tokens: int, |
| 45 | + top_k: int) -> tuple[torch.Tensor, torch.Tensor]: |
| 46 | + original_total_elements = num_tokens * top_k |
| 47 | + device = topk_ids.device |
| 48 | + original_dtype = topk_ids.dtype |
| 49 | + |
| 50 | + if original_total_elements == 0: |
| 51 | + output_len = ep_size * max_row_per_ep_rank |
| 52 | + topk_ids_pad = torch.full((output_len, ), |
| 53 | + expert_num, |
| 54 | + dtype=original_dtype, |
| 55 | + device=device) |
| 56 | + unpad_indices = torch.full((original_total_elements, ), |
| 57 | + -1, |
| 58 | + dtype=torch.long, |
| 59 | + device=device) |
| 60 | + return topk_ids_pad, unpad_indices |
| 61 | + |
| 62 | + experts_per_ep_rank_val = expert_num // ep_size |
| 63 | + if experts_per_ep_rank_val == 0: |
| 64 | + raise ValueError( |
| 65 | + "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. " |
| 66 | + "Ensure expert_num >= ep_size.") |
| 67 | + |
| 68 | + assigned_ep_rank = (topk_ids.float() / |
| 69 | + experts_per_ep_rank_val).to(original_dtype) |
| 70 | + indices_arange = torch.arange(topk_ids.shape[0], device=device) |
| 71 | + |
| 72 | + is_new_segment = torch.cat((torch.tensor([True], device=device), |
| 73 | + assigned_ep_rank[1:] != assigned_ep_rank[:-1])) |
| 74 | + temp_start_markers = torch.full_like(indices_arange, |
| 75 | + -1, |
| 76 | + dtype=indices_arange.dtype) |
| 77 | + temp_start_markers[is_new_segment] = indices_arange[is_new_segment] |
| 78 | + start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0] |
| 79 | + token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token |
| 80 | + is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank |
| 81 | + cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) |
| 82 | + indices_in_rec_cond_list_for_all = cumsum_kept - 1 |
| 83 | + unpad_indices = torch.where( |
| 84 | + is_kept_mask, indices_in_rec_cond_list_for_all, |
| 85 | + torch.tensor(-1, device=device, dtype=torch.long)) |
| 86 | + output_len = ep_size * max_row_per_ep_rank |
| 87 | + topk_ids_pad = torch.full((output_len, ), |
| 88 | + expert_num, |
| 89 | + dtype=original_dtype, |
| 90 | + device=device) |
| 91 | + if topk_ids.shape[0] > 0: |
| 92 | + all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx |
| 93 | + temp_pad_buffer = torch.full((output_len + 1, ), |
| 94 | + expert_num, |
| 95 | + dtype=original_dtype, |
| 96 | + device=device) |
| 97 | + output_len_tensor = torch.tensor(output_len, |
| 98 | + dtype=torch.long, |
| 99 | + device=device) |
| 100 | + scatter_indices = torch.where(is_kept_mask, all_destination_indices, |
| 101 | + output_len_tensor) |
| 102 | + temp_pad_buffer.scatter_(0, scatter_indices, topk_ids) |
| 103 | + topk_ids_pad = temp_pad_buffer[:output_len] |
| 104 | + return topk_ids_pad, unpad_indices |
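A hand-traced toy call of process_topk_ids, to make the per-rank padding and dropping behaviour concrete. The values are illustrative only: expanded_expert_idx is assumed to arrive sorted by expert id (as npu_moe_init_routing produces it in the caller below), and the expected outputs in the comments were worked out from the code above, not taken from the PR.

```python
import torch
# process_topk_ids as defined above (import path omitted on purpose).

# 3 tokens, top_k=2, 4 experts over 2 EP ranks, at most 2 rows kept per rank.
expanded_expert_idx = torch.tensor([0, 1, 1, 2, 3, 3], dtype=torch.int32)
topk_ids_pad, unpad_indices = process_topk_ids(
    expanded_expert_idx, expert_num=4, ep_size=2,
    max_row_per_ep_rank=2, num_tokens=3, top_k=2)

# topk_ids_pad  -> tensor([0, 1, 2, 3], dtype=torch.int32)
#   fixed length ep_size * max_row_per_ep_rank; any unused/overflow slot would
#   hold the sentinel value expert_num (= 4).
# unpad_indices -> tensor([0, 1, -1, 2, 3, -1])
#   -1 marks expanded rows dropped because their EP rank's row budget was full.
```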
40 | 105 |
|
41 | 106 |
|
42 | 107 | def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
@@ -146,8 +211,62 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
146 | 211 | return hidden_states
|
147 | 212 |
|
148 | 213 |
|
149 | | -# currently expert parallelism implemented with all2all
150 | | -# is under-optimized.
| 214 | +def apply_mlp(hidden_states_wrapper: List[torch.Tensor], |
| 215 | + w1: torch.Tensor, |
| 216 | + w2: torch.Tensor, |
| 217 | + group_list: torch.Tensor, |
| 218 | + group_list_type: int = 1) -> torch.Tensor: |
| 219 | + """ |
| 220 | + apply MLP: gate_up_proj -> swiglu -> down_proj |
| 221 | +
|
| 222 | + Args: |
| 223 | + hidden_states_wrapper: single-element list wrapping the input hidden states of shape (num_tokens, hidden_size). |
| 224 | + w1: gate/up projection expert weights, passed with shape |
| 225 | + (num_experts, intermediate_size * 2, hidden_size) |
| 226 | + w2: down projection expert weights, passed with shape |
| 227 | + (num_experts, hidden_size, intermediate_size) |
| 228 | + group_list: number of tokens handled by each expert (cumulative sums or |
| 229 | + per-group counts, depending on group_list_type), with shape (num_experts,). |
| 230 | + Note: both weights are transposed internally before the grouped matmul: |
| 231 | + w1: (num_experts, intermediate_size * 2, hidden_size) -> |
| 232 | + (num_experts, hidden_size, intermediate_size * 2) |
| 233 | + w2: (num_experts, hidden_size, intermediate_size) -> |
| 234 | + (num_experts, intermediate_size, hidden_size) |
| 235 | +
|
| 236 | + Returns: |
| 237 | + hidden_states: output hidden states after MLP. |
| 238 | + """ |
| 239 | + |
| 240 | + assert len(hidden_states_wrapper) == 1 |
| 241 | + hidden_states = hidden_states_wrapper.pop() |
| 242 | + |
| 243 | + w1 = w1.transpose(1, 2) |
| 244 | + hidden_states = torch_npu.npu_grouped_matmul( |
| 245 | + x=[hidden_states], |
| 246 | + weight=[w1], |
| 247 | + split_item=2, |
| 248 | + group_list_type=group_list_type, |
| 249 | + group_type=0, |
| 250 | + group_list=group_list, |
| 251 | + ) |
| 252 | + |
| 253 | + hidden_states = torch.cat(hidden_states, dim=0) |
| 254 | + hidden_states = torch_npu.npu_swiglu(hidden_states) |
| 255 | + |
| 256 | + w2 = w2.transpose(1, 2) |
| 257 | + hidden_states = torch_npu.npu_grouped_matmul( |
| 258 | + x=[hidden_states], |
| 259 | + weight=[w2], |
| 260 | + split_item=2, |
| 261 | + group_list_type=group_list_type, |
| 262 | + group_type=0, |
| 263 | + group_list=group_list, |
| 264 | + ) |
| 265 | + |
| 266 | + hidden_states = torch.cat(hidden_states, dim=0) |
| 267 | + return hidden_states |
| 268 | + |
| 269 | + |
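For readers without an NPU at hand, here is a rough eager-PyTorch reference of the computation apply_mlp performs. It assumes group_list carries cumulative token counts (the group_list_type=0 case used by the all2all-buffer path below) and that npu_swiglu follows the usual silu(a) * b split of the last dimension; it is a sketch of the math, not the kernel's implementation.

```python
import torch
import torch.nn.functional as F

def apply_mlp_reference(hidden_states: torch.Tensor,  # (num_tokens, hidden)
                        w1: torch.Tensor,              # (E, intermediate*2, hidden)
                        w2: torch.Tensor,              # (E, hidden, intermediate)
                        group_list: torch.Tensor       # (E,), cumulative token counts
                        ) -> torch.Tensor:
    out, start = [], 0
    for e in range(w1.shape[0]):
        end = int(group_list[e])
        x = hidden_states[start:end]                 # tokens routed to expert e
        gate_up = x @ w1[e].transpose(0, 1)          # gate_up_proj: (n_e, intermediate*2)
        a, b = gate_up.chunk(2, dim=-1)
        x = F.silu(a) * b                            # swiglu
        out.append(x @ w2[e].transpose(0, 1))        # down_proj: (n_e, hidden)
        start = end
    return torch.cat(out, dim=0)
```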
151 | 270 | def fused_experts_with_all2all(
|
152 | 271 | hidden_states: torch.Tensor,
|
153 | 272 | w1: torch.Tensor,
|
@@ -283,6 +402,133 @@ def fused_experts_with_all2all(
|
283 | 402 | return final_hidden_states
|
284 | 403 |
|
285 | 404 |
|
| 405 | +# currently expert parallelism implemented with all2all |
| 406 | +# is under-optimized. |
| 407 | +def fused_experts_with_all2all_buffer( |
| 408 | + hidden_states: torch.Tensor, |
| 409 | + w1: torch.Tensor, |
| 410 | + w2: torch.Tensor, |
| 411 | + topk_weights: torch.Tensor, |
| 412 | + topk_ids: torch.Tensor, |
| 413 | + top_k: int, |
| 414 | + max_model_len: int, |
| 415 | + global_batch_size: int, |
| 416 | + expert_map: torch.Tensor = None, |
| 417 | + ep_group: GroupCoordinator = None, |
| 418 | +): |
| 419 | + original_shape = hidden_states.shape |
| 420 | + if len(original_shape) == 3: |
| 421 | + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) |
| 422 | + |
| 423 | + num_tokens, _ = hidden_states.shape |
| 424 | + device = hidden_states.device |
| 425 | + |
| 426 | + global_num_experts = len(expert_map) |
| 427 | + local_num_experts = global_num_experts // ep_group.world_size |
| 428 | + row_idx_len = num_tokens * top_k |
| 429 | + row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32, |
| 430 | + device=device).view(top_k, |
| 431 | + -1).permute(1, 0).contiguous()) |
| 432 | + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( |
| 433 | + hidden_states, |
| 434 | + row_idx=row_idx, |
| 435 | + expert_idx=topk_ids, |
| 436 | + active_num=num_tokens) |
| 437 | + |
| 438 | + max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) * |
| 439 | + max_model_len // ep_group.world_size + |
| 440 | + 1) * top_k * 2 |
| 441 | + expert_idx_buffer_scatter, unpad_indices = process_topk_ids( |
| 442 | + expanded_expert_idx, global_num_experts, ep_group.world_size, |
| 443 | + max_row_per_ep_rank, num_tokens, top_k) |
| 444 | + hidden_states_pad_idx = torch.zeros( |
| 445 | + expert_idx_buffer_scatter.shape, |
| 446 | + dtype=expert_idx_buffer_scatter.dtype, |
| 447 | + device=expert_idx_buffer_scatter.device) |
| 448 | + non_pad_len = torch.sum( |
| 449 | + (expert_idx_buffer_scatter != global_num_experts).to(torch.int32)) |
| 450 | + hidden_states_pad_idx[ |
| 451 | + expert_idx_buffer_scatter != global_num_experts] = torch.arange( |
| 452 | + non_pad_len, |
| 453 | + dtype=expert_idx_buffer_scatter.dtype, |
| 454 | + device=hidden_states.device) |
| 455 | + |
| 456 | + hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] |
| 457 | + expert_idx_buffer_gather = torch.empty_like( |
| 458 | + expert_idx_buffer_scatter, |
| 459 | + dtype=expert_idx_buffer_scatter.dtype, |
| 460 | + device=expert_idx_buffer_scatter.device) |
| 461 | + hidden_states_buffer_gather = torch.empty_like( |
| 462 | + hidden_states_buffer_scatter, |
| 463 | + dtype=hidden_states_buffer_scatter.dtype, |
| 464 | + device=hidden_states_buffer_scatter.device) |
| 465 | + dist.all_to_all_single(expert_idx_buffer_gather, |
| 466 | + expert_idx_buffer_scatter, |
| 467 | + group=ep_group.device_group) |
| 468 | + dist.all_to_all_single(hidden_states_buffer_gather, |
| 469 | + hidden_states_buffer_scatter, |
| 470 | + group=ep_group.device_group) |
| 471 | + mask = expert_idx_buffer_gather != global_num_experts |
| 472 | + local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( |
| 473 | + global_num_experts // ep_group.world_size) |
| 474 | + hidden_states = hidden_states_buffer_gather[mask] |
| 475 | + idx_type = local_expert_idx.dtype |
| 476 | + sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float()) |
| 477 | + sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) |
| 478 | + |
| 479 | + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( |
| 480 | + sorted_local_expert_idx, local_num_experts).to(torch.int64) |
| 481 | + hidden_states = hidden_states[sorted_idx] |
| 482 | + group_list_type = 0 |
| 483 | + |
| 484 | + hidden_states_wrapper = [hidden_states] |
| 485 | + del hidden_states |
| 486 | + |
| 487 | + hidden_states = apply_mlp(hidden_states_wrapper, |
| 488 | + w1, |
| 489 | + w2, |
| 490 | + expert_tokens, |
| 491 | + group_list_type=group_list_type) |
| 492 | + |
| 493 | + resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype) |
| 494 | + hidden_states = hidden_states[resorted_idx] |
| 495 | + hidden_states_scatter = torch.zeros( |
| 496 | + (mask.shape[0], hidden_states.shape[1]), |
| 497 | + dtype=hidden_states.dtype, |
| 498 | + device=hidden_states.device) |
| 499 | + hidden_states_scatter[mask] = hidden_states |
| 500 | + hidden_states_gather = torch.empty_like( |
| 501 | + hidden_states_scatter, |
| 502 | + dtype=hidden_states_scatter.dtype, |
| 503 | + device=hidden_states_scatter.device) |
| 504 | + dist.all_to_all_single(hidden_states_gather, |
| 505 | + hidden_states_scatter, |
| 506 | + group=ep_group.device_group) |
| 507 | + hidden_states_gather = hidden_states_gather[ |
| 508 | + expert_idx_buffer_scatter != global_num_experts] |
| 509 | + if hidden_states_gather.shape[0] != row_idx_len: |
| 510 | + hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]), |
| 511 | + dtype=hidden_states.dtype, |
| 512 | + device=hidden_states.device) |
| 513 | + hidden_states[unpad_indices != -1] = hidden_states_gather |
| 514 | + else: |
| 515 | + # TODO: Reorder device memory 2 times here, replace the current |
| 516 | + hidden_states = hidden_states_gather |
| 517 | + final_hidden_states = torch_npu.npu_moe_finalize_routing( |
| 518 | + hidden_states, |
| 519 | + skip1=None, |
| 520 | + skip2=None, |
| 521 | + bias=None, |
| 522 | + scales=topk_weights, |
| 523 | + expanded_src_to_dst_row=expanded_row_idx, |
| 524 | + export_for_source_row=topk_ids, |
| 525 | + ) |
| 526 | + |
| 527 | + if len(original_shape) == 3: |
| 528 | + final_hidden_states = final_hidden_states.view(original_shape) |
| 529 | + return final_hidden_states |
| 530 | + |
| 531 | + |
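The per-rank buffer capacity computed inside fused_experts_with_all2all_buffer (max_row_per_ep_rank) relies on the -(-a // b) ceiling-division idiom. A quick hand check of that arithmetic with made-up numbers; none of these values come from the PR:

```python
# Illustrative numbers only.
global_batch_size, max_model_len, ep_world_size, top_k = 32, 4096, 8, 8

seqs_per_rank = -(-global_batch_size // ep_world_size)        # ceil(32 / 8) = 4
max_row_per_ep_rank = (seqs_per_rank * max_model_len
                       // ep_world_size + 1) * top_k * 2
# (4 * 4096 // 8 + 1) * 8 * 2 = (2048 + 1) * 16 = 32784 rows reserved per EP rank
```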
286 | 532 | def fused_experts(
|
287 | 533 | hidden_states: torch.Tensor,
|
288 | 534 | w1: torch.Tensor,
|
@@ -585,6 +831,7 @@ def __init__(self, moe: MoEConfig = None):
|
585 | 831 | self.ep_size = ep_group.world_size
|
586 | 832 | self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
587 | 833 | self.local_batch_size = self.global_batch_size // self.ep_size
|
| 834 | + self.max_model_len = vllm_config.model_config.max_model_len |
588 | 835 |
|
589 | 836 | ascend_config = get_ascend_config()
|
590 | 837 | self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
@@ -613,21 +860,22 @@ def apply(
|
613 | 860 | self,
|
614 | 861 | layer: torch.nn.Module,
|
615 | 862 | x: torch.Tensor,
|
616 | | - use_grouped_topk: bool,
617 | | - top_k: int,
618 | 863 | router_logits: torch.Tensor,
|
| 864 | + top_k: int, |
619 | 865 | renormalize: bool,
|
620 | | - topk_group: Optional[int] = None,
621 | | - num_expert_group: Optional[int] = None,
| 866 | + use_grouped_topk: bool = False, |
622 | 867 | global_num_experts: int = -1,
|
623 | 868 | expert_map: Optional[torch.Tensor] = None,
|
| 869 | + topk_group: Optional[int] = None, |
| 870 | + num_expert_group: Optional[int] = None, |
624 | 871 | custom_routing_function: Optional[Callable] = None,
|
625 | 872 | scoring_func: str = "softmax",
|
626 | 873 | e_score_correction_bias: Optional[torch.Tensor] = None,
|
627 | 874 | is_prefill: bool = False,
|
628 | 875 | enable_force_load_balance: bool = False,
|
629 | 876 | **kwargs,
|
630 | | - ):
| 877 | + ) -> torch.Tensor: |
| 878 | + |
631 | 879 | # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
632 | 880 | if global_num_experts == 256:
|
633 | 881 | topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
@@ -683,11 +931,19 @@ def apply(
|
683 | 931 | topk_ids=topk_ids,
|
684 | 932 | top_k=top_k,
|
685 | 933 | expert_map=expert_map)
|
| 934 | + elif MOE_ALL2ALL_BUFFER: |
| 935 | + return fused_experts_with_all2all_buffer( |
| 936 | + hidden_states=x, |
| 937 | + w1=layer.w13_weight, |
| 938 | + w2=layer.w2_weight, |
| 939 | + topk_weights=topk_weights, |
| 940 | + topk_ids=topk_ids, |
| 941 | + top_k=top_k, |
| 942 | + max_model_len=self.max_model_len, |
| 943 | + global_batch_size=self.global_batch_size, |
| 944 | + expert_map=expert_map, |
| 945 | + ep_group=get_ep_group()) |
686 | 946 | else:
|
687 | | - # The current implementation of deepseek moe splits hidden_states
688 | | - # according to tp_size before they are feed into fused_moe module.
689 | | - # Therefore, all2all is needed no matter how dp/tp is set so as to
690 | | - # dispatch/combine tokens.
691 | 947 | return fused_experts_with_all2all(hidden_states=x,
|
692 | 948 | w1=layer.w13_weight,
|
693 | 949 | w2=layer.w2_weight,
|
|