@@ -98,6 +98,41 @@ def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
         return super().forward_oot(x)


+class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        self.output_sizes = output_sizes
+        super().__init__(input_size,
+                         sum(output_sizes),
+                         bias=bias,
+                         quant_config=quant_config,
+                         prefix=prefix)
+
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor, loaded_shard_id: int):
+        # With no support for GGUF format yet.
+        assert not getattr(param, "is_gguf_weight", False)
+        assert not getattr(param, "is_gguf_weight_type", False)
+
+        assert loaded_shard_id < len(self.output_sizes)
+        shard_offset = sum(self.output_sizes[:loaded_shard_id])
+        shard_size = self.output_sizes[loaded_shard_id]
+        shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
+
+        assert shard.size() == loaded_weight.size(), (
+            f"Tried to load weights of size {loaded_weight.size()} "
+            f"to a parameter shard of id {loaded_shard_id} size {shard.size()}"
+        )
+        shard.copy_(loaded_weight)
+
+
 class CustomDeepseekV2MLP(nn.Module):

     def __init__(
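A minimal standalone sketch of the shard bookkeeping that the `weight_loader` above performs, using plain tensors and hypothetical sizes; the real parameter carries an `output_dim` attribute set by vLLM's weight-loading machinery:

    import torch

    # Hypothetical per-shard output widths for a merged gate_proj/up_proj weight.
    output_sizes = [4, 4]
    input_size = 3
    output_dim = 0  # rows of the merged weight hold the concatenated outputs

    merged_weight = torch.empty(sum(output_sizes), input_size)

    def load_shard(loaded_weight: torch.Tensor, loaded_shard_id: int) -> None:
        # Same arithmetic as weight_loader above: the offset is the sum of the
        # widths of all preceding shards, the size is this shard's width.
        shard_offset = sum(output_sizes[:loaded_shard_id])
        shard_size = output_sizes[loaded_shard_id]
        shard = merged_weight.narrow(output_dim, shard_offset, shard_size)
        assert shard.size() == loaded_weight.size()
        shard.copy_(loaded_weight)

    load_shard(torch.ones(4, 3), 0)   # gate_proj lands in rows 0..3
    load_shard(torch.zeros(4, 3), 1)  # up_proj lands in rows 4..7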
@@ -107,20 +142,33 @@ def __init__(
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        force_replicate: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
+        if not force_replicate:
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = RowParallelLinear(intermediate_size,
+                                               hidden_size,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               reduce_results=reduce_results,
+                                               prefix=f"{prefix}.down_proj")
+        else:
+            self.gate_up_proj = CustomDeepseekV2MergedReplicatedLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = ReplicatedLinear(intermediate_size,
+                                              hidden_size,
+                                              bias=False,
+                                              quant_config=quant_config,
+                                              prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -181,6 +229,12 @@ def __init__(
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")

+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
+
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
                                      bias=False,
@@ -216,6 +270,7 @@ def __init__(
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=True,
+                force_replicate=self.enable_multistream_moe,
                 prefix=f"{prefix}.shared_experts",
             )
         else:
@@ -230,12 +285,6 @@ def __init__(

         self.params_dtype = torch.get_default_dtype()

-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -274,27 +323,22 @@ def forward(
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)

-        kwargs = {}
-        if not use_separated_shared_experts:
-            kwargs.update({
-                "shared_experts": self.shared_experts,
-                "shared_experts_input": old_hidden_states
-            })
-
         experts_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-            **kwargs)
+            shared_experts=(self.shared_experts
+                            if not use_separated_shared_experts else None),
+        )

         if not isinstance(experts_hidden_states, tuple):
             hidden_states = experts_hidden_states * self.routed_scaling_factor
         else:
-            hidden_states = experts_hidden_states[
-                0] * self.routed_scaling_factor
-            shared_hidden_states = experts_hidden_states[1]
+            hidden_states = (
+                experts_hidden_states[0] * self.routed_scaling_factor +
+                experts_hidden_states[1])

         if self.tp_size > 1:
             if (VLLM_ENABLE_MC2
@@ -309,10 +353,8 @@ def forward(
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)

         if use_separated_shared_experts:
-            shared_hidden_states = self.shared_experts(old_hidden_states)
-
-        if self.shared_experts is not None:
-            hidden_states = hidden_states + shared_hidden_states
+            hidden_states = hidden_states + self.shared_experts(
+                old_hidden_states)

         return hidden_states.view(num_tokens, hidden_size)

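The restructured forward can be read as a single linear combination: when the fused path returns a (routed, shared) tuple it is summed immediately, and when the shared experts run separately their output is added after the optional all-reduce. A small numeric sketch with made-up tensors and a hypothetical scaling factor:

    import torch

    routed_scaling_factor = 2.5          # hypothetical config value
    routed_out = torch.randn(4, 16)      # experts_hidden_states[0]
    shared_out = torch.randn(4, 16)      # experts_hidden_states[1] or self.shared_experts(old_hidden_states)

    # Either branch ends up in the same form.
    hidden_states = routed_out * routed_scaling_factor + shared_out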