|
26 | 26 | select_experts)
|
27 | 27 | from vllm_ascend.utils import is_310p
|
28 | 28 |
|
29 |
| -original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__ |
30 | 29 |
|
| 30 | +@UnquantizedFusedMoEMethod.register_oot |
| 31 | +class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): |
| 32 | + """This UnquantizedFusedMoEMethod is used for qwen3-moe. |
| 33 | + Customize it mainly to support aclgraph |
| 34 | + """ |
31 | 35 |
|
32 |
| -def unquantized_fused_moe_init_func(self, *args, **kwargs): |
33 |
| - original_unquantized_fused_moe_init_func(self, *args, **kwargs) |
34 |
| - vllm_config = get_current_vllm_config() |
35 |
| - self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens |
36 |
| - self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager |
| 36 | + def __init__(self, *args, **kwargs): |
| 37 | + super().__init__(self, *args, **kwargs) |
| 38 | + vllm_config = get_current_vllm_config() |
| 39 | + self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens |
| 40 | + self.use_aclgraph = vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not vllm_config.model_config.enforce_eager |
37 | 41 |
|
| 42 | + def forward_oot( |
| 43 | + self, |
| 44 | + layer: torch.nn.Module, |
| 45 | + x: torch.Tensor, |
| 46 | + use_grouped_topk: bool, |
| 47 | + top_k: int, |
| 48 | + router_logits: torch.Tensor, |
| 49 | + renormalize: bool, |
| 50 | + topk_group: Optional[int] = None, |
| 51 | + num_expert_group: Optional[int] = None, |
| 52 | + custom_routing_function: Optional[Callable] = None, |
| 53 | + scoring_func: str = "softmax", |
| 54 | + e_score_correction_bias: Optional[torch.Tensor] = None, |
| 55 | + global_num_experts: Optional[int] = None, |
| 56 | + expert_map: Optional[torch.Tensor] = None, |
| 57 | + apply_router_weight_on_input: bool = False, |
| 58 | + activation: str = "silu", |
| 59 | + ) -> torch.Tensor: |
| 60 | + topk_weights, topk_ids = select_experts( |
| 61 | + global_num_experts=global_num_experts, |
| 62 | + hidden_states=x, |
| 63 | + router_logits=router_logits, |
| 64 | + top_k=top_k, |
| 65 | + use_grouped_topk=use_grouped_topk, |
| 66 | + renormalize=renormalize, |
| 67 | + topk_group=topk_group, |
| 68 | + num_expert_group=num_expert_group, |
| 69 | + custom_routing_function=custom_routing_function, |
| 70 | + scoring_func=scoring_func, |
| 71 | + e_score_correction_bias=e_score_correction_bias, |
| 72 | + ) |
38 | 73 |
|
39 |
| -def forward_oot( |
40 |
| - self, |
41 |
| - layer: torch.nn.Module, |
42 |
| - x: torch.Tensor, |
43 |
| - use_grouped_topk: bool, |
44 |
| - top_k: int, |
45 |
| - router_logits: torch.Tensor, |
46 |
| - renormalize: bool, |
47 |
| - topk_group: Optional[int] = None, |
48 |
| - num_expert_group: Optional[int] = None, |
49 |
| - custom_routing_function: Optional[Callable] = None, |
50 |
| - scoring_func: str = "softmax", |
51 |
| - e_score_correction_bias: Optional[torch.Tensor] = None, |
52 |
| - global_num_experts: Optional[int] = None, |
53 |
| - expert_map: Optional[torch.Tensor] = None, |
54 |
| - apply_router_weight_on_input: bool = False, |
55 |
| - activation: str = "silu", |
56 |
| -) -> torch.Tensor: |
57 |
| - topk_weights, topk_ids = select_experts( |
58 |
| - global_num_experts=global_num_experts, |
59 |
| - hidden_states=x, |
60 |
| - router_logits=router_logits, |
61 |
| - top_k=top_k, |
62 |
| - use_grouped_topk=use_grouped_topk, |
63 |
| - renormalize=renormalize, |
64 |
| - topk_group=topk_group, |
65 |
| - num_expert_group=num_expert_group, |
66 |
| - custom_routing_function=custom_routing_function, |
67 |
| - scoring_func=scoring_func, |
68 |
| - e_score_correction_bias=e_score_correction_bias, |
69 |
| - ) |
| 74 | + if topk_ids.shape[1] < top_k or is_310p(): |
| 75 | + assert global_num_experts is not None |
| 76 | + return fused_experts_moge( |
| 77 | + hidden_states=x, |
| 78 | + w1=layer.w13_weight, |
| 79 | + w2=layer.w2_weight, |
| 80 | + topk_weights=topk_weights, |
| 81 | + topk_ids=topk_ids, |
| 82 | + top_k=top_k, |
| 83 | + global_num_experts=global_num_experts, |
| 84 | + expert_map=expert_map, |
| 85 | + apply_router_weight_on_input=apply_router_weight_on_input) |
70 | 86 |
|
71 |
| - if topk_ids.shape[1] < top_k or is_310p(): |
72 |
| - assert global_num_experts is not None |
73 |
| - return fused_experts_moge( |
| 87 | + # If use aclgraph, we need to set max_num_tokens to make |
| 88 | + # the input shape of `npu_moe_init_routing` fixed |
| 89 | + max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None |
| 90 | + |
| 91 | + return fused_experts( |
74 | 92 | hidden_states=x,
|
75 | 93 | w1=layer.w13_weight,
|
76 | 94 | w2=layer.w2_weight,
|
77 | 95 | topk_weights=topk_weights,
|
78 | 96 | topk_ids=topk_ids,
|
79 | 97 | top_k=top_k,
|
80 |
| - global_num_experts=global_num_experts, |
81 | 98 | expert_map=expert_map,
|
82 |
| - apply_router_weight_on_input=apply_router_weight_on_input) |
83 |
| - |
84 |
| - # If use aclgraph, we need to set max_num_tokens to make |
85 |
| - # the input shape of `npu_moe_init_routing` fixed |
86 |
| - max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None |
87 |
| - |
88 |
| - return fused_experts( |
89 |
| - hidden_states=x, |
90 |
| - w1=layer.w13_weight, |
91 |
| - w2=layer.w2_weight, |
92 |
| - topk_weights=topk_weights, |
93 |
| - topk_ids=topk_ids, |
94 |
| - top_k=top_k, |
95 |
| - expert_map=expert_map, |
96 |
| - apply_router_weight_on_input=apply_router_weight_on_input, |
97 |
| - max_num_tokens=max_num_tokens) |
98 |
| - |
99 |
| - |
100 |
| -UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func |
101 |
| -UnquantizedFusedMoEMethod.forward_oot = forward_oot |
| 99 | + apply_router_weight_on_input=apply_router_weight_on_input, |
| 100 | + max_num_tokens=max_num_tokens) |
0 commit comments