 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import AttentionMetadata
-from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
-                         get_current_vllm_config)
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ReplicatedLinear,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_v2 import \
     DeepseekV2ForCausalLM  # noqa: E501
-from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
 from vllm.model_executor.models.utils import (
     PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
 from vllm.sequence import IntermediateTensors
 
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
-from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention,
-                                            CustomDeepseekV2MLP)
+from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2DecoderLayer,
+                                            CustomDeepseekV2MLP,
+                                            CustomDeepseekV2MoE)
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,
 from vllm_ascend.multistream.metadata import (MultiStreamConfig,
                                               MultiStreamStepMetadata,
                                               make_multistream_metadata_ds)
-from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.w8a8_dynamic import (
     AscendW8A8DynamicLinearMethod, apply_mlp)
 from vllm_ascend.utils import dispose_tensor
@@ -126,7 +122,7 @@ def _forward_ms_mlp(self, x):
         return x
 
 
-class CustomDeepseekDBOMoE(nn.Module):
+class CustomDeepseekDBOMoE(CustomDeepseekV2MoE):
 
     top_k: int
 
@@ -136,45 +132,9 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.routed_scaling_factor = config.routed_scaling_factor
-        self.n_shared_experts = config.n_shared_experts
-        self.routed_scaling_factor = config.routed_scaling_factor
-        if self.tp_size > config.n_routed_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.n_routed_experts}.")
-
-        if config.hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
-                             "Only silu is supported for now.")
-
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.n_routed_experts,
-                                     bias=False,
-                                     quant_config=None,
-                                     prefix=f"{prefix}.gate")
-        if config.topk_method == "noaux_tc":
-            self.gate.e_score_correction_bias = nn.Parameter(
-                torch.empty(config.n_routed_experts))
-        else:
-            self.gate.e_score_correction_bias = None
-
-        self.experts = AscendFusedMoE(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
-            prefix=f"{prefix}.experts",
-            scoring_func=config.scoring_func,
-            e_score_correction_bias=self.gate.e_score_correction_bias)
+        super().__init__(config=config,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
@@ -189,19 +149,6 @@ def __init__(
             )
         CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok
 
-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.kv_consumer = None
-        transfer_config = get_current_vllm_config().kv_transfer_config
-        if transfer_config is not None:
-            self.kv_consumer = transfer_config.kv_role = "kv_consumer"
-        self.params_dtype = torch.get_default_dtype()
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -254,7 +201,7 @@ def _forward_ms_op_gate(
         return router_logits
 
 
-class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer):
+class CustomDeepseekDBODecoderLayer(CustomDeepseekV2DecoderLayer):
 
     def __init__(
         self,
@@ -264,43 +211,19 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
-        nn.Module.__init__(self)
-        self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
-        # DecoderLayers are created with `make_layers` which passes the prefix
-        # with the layer's index.
-        layer_idx = int(prefix.split(sep='.')[-1])
-        self.layer_idx = layer_idx
-        # TODO: enable mla in vllm-ascend
-        attn_cls = CustomDeepseekV2MLAAttention
-        self.self_attn = attn_cls(
-            config=config,
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            qk_nope_head_dim=config.qk_nope_head_dim,
-            qk_rope_head_dim=config.qk_rope_head_dim,
-            v_head_dim=config.v_head_dim,
-            q_lora_rank=config.q_lora_rank
-            if hasattr(config, "q_lora_rank") else None,
-            kv_lora_rank=config.kv_lora_rank,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
-        )
+        super().__init__(config=config,
+                         prefix=prefix,
+                         model_config=model_config,
+                         cache_config=cache_config,
+                         quant_config=quant_config)
         self.tp_size = get_tensor_model_parallel_world_size()
         self.dp_size = get_dp_group().world_size
         self.tp_group = get_tp_group().device_group
         self.global_num_experts = config.n_routed_experts
 
         if (config.n_routed_experts is not None
-                and layer_idx >= config.first_k_dense_replace
-                and layer_idx % config.moe_layer_freq == 0):
+                and self.layer_idx >= config.first_k_dense_replace
+                and self.layer_idx % config.moe_layer_freq == 0):
             self.mlp = CustomDeepseekDBOMoE(
                 config=config,
                 quant_config=quant_config,
@@ -314,11 +237,6 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
-        self.routed_scaling_factor = config.routed_scaling_factor
 
     def forward(
         self,
@@ -926,7 +844,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(config.vocab_size,
                                           config.hidden_size,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=maybe_prefix(
+                                              prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
         self.logits_processor = LogitsProcessor(config.vocab_size)
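
Note: the DBO classes above now delegate their shared construction to the DeepSeek-V2 base classes in vllm_ascend.models.deepseek_v2 and keep only the state that dual-batch overlap needs. A minimal, self-contained sketch of that delegation pattern follows; the BaseMoE/DBOMoE names are illustrative stand-ins, not part of the vLLM-Ascend API.

from typing import Any, Optional


class BaseMoE:
    # Stand-in for CustomDeepseekV2MoE: owns the gate/expert/shared-expert
    # setup so subclasses do not have to repeat it.
    def __init__(self,
                 config: Any,
                 quant_config: Optional[Any] = None,
                 prefix: str = ""):
        self.config = config
        self.quant_config = quant_config
        self.prefix = prefix


class DBOMoE(BaseMoE):
    # Stand-in for CustomDeepseekDBOMoE: reuses the parent constructor and
    # adds only the DBO-specific class attribute.
    top_k: int

    def __init__(self,
                 config: Any,
                 quant_config: Optional[Any] = None,
                 prefix: str = ""):
        super().__init__(config=config,
                         quant_config=quant_config,
                         prefix=prefix)
        DBOMoE.top_k = config.num_experts_per_tok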