@@ -101,6 +101,41 @@ def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
         return super().forward_oot(x)
 
 
+class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        self.output_sizes = output_sizes
+        super().__init__(input_size,
+                         sum(output_sizes),
+                         bias=bias,
+                         quant_config=quant_config,
+                         prefix=prefix)
+
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor, loaded_shard_id: int):
+        # With no support for GGUF format yet.
+        assert not getattr(param, "is_gguf_weight", False)
+        assert not getattr(param, "is_gguf_weight_type", False)
+
+        assert loaded_shard_id < len(self.output_sizes)
+        shard_offset = sum(self.output_sizes[:loaded_shard_id])
+        shard_size = self.output_sizes[loaded_shard_id]
+        shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
+
+        assert shard.size() == loaded_weight.size(), (
+            f"Tried to load weights of size {loaded_weight.size()} "
+            f"to a parameter shard of id {loaded_shard_id} size {shard.size()}"
+        )
+        shard.copy_(loaded_weight)
+
+
 class CustomDeepseekV2MLP(nn.Module):
 
     def __init__(
@@ -110,20 +145,33 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        force_replicate: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
+        if not force_replicate:
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = RowParallelLinear(intermediate_size,
+                                               hidden_size,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               reduce_results=reduce_results,
+                                               prefix=f"{prefix}.down_proj")
+        else:
+            self.gate_up_proj = CustomDeepseekV2MergedReplicatedLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = ReplicatedLinear(intermediate_size,
+                                              hidden_size,
+                                              bias=False,
+                                              quant_config=quant_config,
+                                              prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -189,6 +237,12 @@ def __init__(
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
 
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
+
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
                                      bias=False,
@@ -224,6 +278,7 @@ def __init__(
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=True,
+                force_replicate=self.enable_multistream_moe,
                 prefix=f"{prefix}.shared_experts",
             )
         else:
@@ -238,12 +293,6 @@ def __init__(
 
         self.params_dtype = torch.get_default_dtype()
 
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -282,27 +331,22 @@ def forward(
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        kwargs = {}
-        if not use_separated_shared_experts:
-            kwargs.update({
-                "shared_experts": self.shared_experts,
-                "shared_experts_input": old_hidden_states
-            })
-
         experts_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-            **kwargs)
+            shared_experts=(self.shared_experts
+                            if not use_separated_shared_experts else None),
+        )
 
         if not isinstance(experts_hidden_states, tuple):
             hidden_states = experts_hidden_states * self.routed_scaling_factor
         else:
-            hidden_states = experts_hidden_states[
-                0] * self.routed_scaling_factor
-            shared_hidden_states = experts_hidden_states[1]
+            hidden_states = (
+                experts_hidden_states[0] * self.routed_scaling_factor +
+                experts_hidden_states[1])
 
         if self.tp_size > 1:
             if (VLLM_ENABLE_MC2
@@ -317,10 +361,8 @@ def forward(
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
         if use_separated_shared_experts:
-            shared_hidden_states = self.shared_experts(old_hidden_states)
-
-        if self.shared_experts is not None:
-            hidden_states = hidden_states + shared_hidden_states
+            hidden_states = hidden_states + self.shared_experts(
+                old_hidden_states)
 
         return hidden_states.view(num_tokens, hidden_size)
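
Note for reviewers: the core of the new CustomDeepseekV2MergedReplicatedLinear.weight_loader is the offset arithmetic that places each logical shard into the merged, fully replicated parameter. Below is a minimal standalone sketch of that logic using plain torch with hypothetical shard sizes and a hypothetical load_shard helper; it mirrors the weight_loader above but is not the actual vLLM class.

import torch

# Hypothetical merged gate/up weight: two shards of 4 rows each, concatenated
# along the output dimension (dim 0), as with [intermediate_size] * 2.
output_sizes = [4, 4]
merged_weight = torch.zeros(sum(output_sizes), 8)

def load_shard(param: torch.Tensor, loaded_weight: torch.Tensor,
               shard_id: int, output_dim: int = 0) -> None:
    # Same arithmetic as the weight_loader in this diff: the offset is the sum
    # of all preceding shard sizes, the length is this shard's size.
    offset = sum(output_sizes[:shard_id])
    size = output_sizes[shard_id]
    shard = param.data.narrow(output_dim, offset, size)
    assert shard.size() == loaded_weight.size()
    shard.copy_(loaded_weight)

load_shard(merged_weight, torch.ones(4, 8), shard_id=1)
assert merged_weight[:4].eq(0).all() and merged_weight[4:].eq(1).all()

With output_sizes = [4, 4], shard id 1 starts at row 4 and spans 4 rows, which is how the gate and up projections can share a single merged weight while each checkpoint shard is loaded independently.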