 # limitations under the License.
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
-from typing import Optional
 
-import torch
-import vllm
-from torch import nn
-from transformers import PretrainedConfig
-from vllm.attention import AttentionMetadata
-from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
-from vllm.distributed.parallel_state import get_dp_group
-from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.linear import ReplicatedLinear
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
-from vllm.distributed.parallel_state import get_ep_group
-from vllm.forward_context import get_forward_context
 
 
-from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ops.fused_moe import AscendFusedMoE
-
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     packed_modules_mapping = {
         "qkv_proj": [
@@ -49,86 +33,3 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
         "experts":
         ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
     }
-
-
-class AscendQwen3MoeSparseMoeBlock(nn.Module):
-    top_k: int
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        if self.tp_size > config.num_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe
-
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.num_experts,
-                                     bias=False,
-                                     quant_config=None,
-                                     prefix=f"{prefix}.gate")
-
-        self.experts = AscendFusedMoE(
-            num_experts=config.num_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            prefix=f"{prefix}.experts")
-
-        self.top_k = config.num_experts_per_tok
-
-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.ep_group = get_ep_group()
-
-        self.params_dtype = torch.get_default_dtype()
-
-    def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        if attn_metadata is None:
-            attn_metadata = get_forward_context().attn_metadata
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        # TODO: need a better flag to indicate whether in profile run or not.
-        if attn_metadata is None:
-            # for profile run
-            is_prefill = True
-            enable_force_load_balance = True
-        else:
-            is_prefill = get_forward_context().with_prefill
-            enable_force_load_balance = False
-            # if hasattr(attn_metadata, 'with_prefill_across_dp'):
-            #     is_prefill = attn_metadata.with_prefill_across_dp
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=self.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-            shared_experts=None)
-
-        return hidden_states
-
-
-vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock = AscendQwen3MoeSparseMoeBlock
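
The last removed line above is what activated the deleted block: it rebound vLLM's own `Qwen3MoeSparseMoeBlock` attribute so that later model construction inside vLLM picked up the Ascend implementation. Below is a minimal sketch of that module-attribute patching pattern, assuming only the vLLM module path shown in the diff; the subclass name is illustrative and not part of the project.

import vllm.model_executor.models.qwen3_moe as qwen3_moe

# Keep a handle to the upstream class so an override could still delegate to it.
_UpstreamSparseMoeBlock = qwen3_moe.Qwen3MoeSparseMoeBlock


class PatchedSparseMoeBlock(_UpstreamSparseMoeBlock):
    """Illustrative stand-in; a real plugin would override __init__/forward."""


# Rebind the attribute before any Qwen3 MoE model is instantiated, so code in
# vLLM that resolves `Qwen3MoeSparseMoeBlock` from this module gets the patch.
qwen3_moe.Qwen3MoeSparseMoeBlock = PatchedSparseMoeBlock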