# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch_npu
+import vllm
+import vllm.envs as envs
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.attention import AttentionMetadata
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+                              get_tp_group)
+from vllm.distributed.parallel_state import get_dp_group
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
+
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.parallel_state import get_ep_group
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM


class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
@@ -33,3 +57,89 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
    }
+
+
+class AscendQwen3MoeSparseMoeBlock(nn.Module):
+
+    top_k: int
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe
+
+        self.gate = ReplicatedLinear(config.hidden_size,
+                                     config.num_experts,
+                                     bias=False,
+                                     quant_config=None,
+                                     prefix=f"{prefix}.gate")
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts")
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        # TODO: need a better flag to indicate whether in profile run or not.
+        if attn_metadata is None:
+            # for profile run
+            is_prefill = True
+            enable_force_load_balance = True
+        else:
+            # is_prefill = attn_metadata.num_prefills > 0
+            is_prefill = False
+            enable_force_load_balance = False
+            if hasattr(attn_metadata, 'with_prefill_across_dp'):
+                is_prefill = attn_metadata.with_prefill_across_dp
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None,
+        )
+
+        return hidden_states
+
+
+vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock = AscendQwen3MoeSparseMoeBlock
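The final assignment monkey patches vLLM at import time, so any Qwen3 MoE model constructed afterwards builds AscendQwen3MoeSparseMoeBlock (and thus AscendFusedMoE) in place of the upstream Qwen3MoeSparseMoeBlock. Below is a minimal sketch of how the patch takes effect; the module path vllm_ascend.models.qwen3_moe and the aliases are assumptions for illustration, not confirmed by this diff:

import vllm.model_executor.models.qwen3_moe as upstream_qwen3_moe

# Importing the vllm-ascend module applies the patch as an import side effect.
# NOTE: the path below is assumed for illustration; use the file's real
# location in the vllm-ascend package.
import vllm_ascend.models.qwen3_moe as ascend_qwen3_moe

# After the import, vLLM's module-level symbol points at the Ascend block,
# so Qwen3 MoE decoder layers built from here on route through AscendFusedMoE.
assert (upstream_qwen3_moe.Qwen3MoeSparseMoeBlock
        is ascend_qwen3_moe.AscendQwen3MoeSparseMoeBlock)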