@@ -211,7 +211,7 @@ def forward(
         top_scores = (
             top_scores * self.route_sclaing_factor
         )  # must multiply the scaling factor
-
+        print("In TokenChoiceTopKRouter, top_scores shape: ", top_scores.shape)
         return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert

     def init_weights(self, init_std: float):
@@ -253,12 +253,6 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
             )
             if model_args.n_shared_experts > 0
             else None
-            # FeedForward(
-            #     dim=dim,
-            #     hidden_dim=hidden_dim * model_args.n_shared_experts,
-            # )
-            # if model_args.n_shared_experts > 0
-            # else None
         )

         # auxiliary-loss-free load balancing
@@ -298,6 +292,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Returns:
             out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
         """
+        print("In MoE input, x shape: ", x.shape)
         bs, slen, dim = x.shape

         # top_scores and selected_indices shape (bs*slen*top_k,)
@@ -308,14 +303,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             num_local_tokens_per_expert,
         ) = self.router(x.reshape(bs * slen, dim), self.expert_bias)

-        print(
-            "In MoE, top_scores shape: ",
-            top_scores.shape,
-            "token_indices: ",
-            token_indices.shape,
-            "num_local_tokens: ",
-            num_local_tokens_per_expert.shape,
-        )
+        # print(
+        #     "In MoE, top_scores shape: ",
+        #     top_scores.shape,
+        #     "token_indices: ",
+        #     token_indices.shape,
+        #     "num_local_tokens: ",
+        #     num_local_tokens_per_expert.shape,
+        # )

         # will be used to update the expert bias for load balancing
         self.tokens_per_expert += num_local_tokens_per_expert
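For orientation, here is a minimal, hypothetical sketch of what the router's three return values look like for token-choice top-k routing. It uses plain `torch.topk`, `argsort`, and `bincount` as stand-ins; the real `TokenChoiceTopKRouter` also applies `expert_bias` and the scaling factor from the first hunk, which are omitted here.

```python
import torch

num_tokens, num_experts, top_k = 8, 4, 2
scores = torch.randn(num_tokens, num_experts).softmax(dim=-1)

# each token picks its top_k experts; flatten to (num_tokens * top_k,)
top_scores, selected_experts = scores.topk(top_k, dim=-1)
top_scores = top_scores.reshape(-1)
selected_experts = selected_experts.reshape(-1)

# number of routed slots assigned to each expert
num_local_tokens_per_expert = torch.bincount(selected_experts, minlength=num_experts)

# sort slots so each expert's tokens are contiguous; flattened slot i
# came from source token i // top_k
sorted_slots = torch.argsort(selected_experts, stable=True)
token_indices_experts_sorted = sorted_slots // top_k
top_scores = top_scores[sorted_slots]
```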
@@ -329,6 +324,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             dim=0,
             index=token_indices,
         )
+        print("Routed input: ", routed_input)
+
+        # TODO: remove this line, this is a temporary test
+        routed_input = (routed_input.to(torch.float32) * top_scores.reshape(-1, 1)).to(
+            x.dtype
+        )

         if self.use_grouped_mm:
             # NOTE: In order to use torch._grouped_mm, we need to make sure
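A toy shape sketch of the input-side scaling added above: `top_scores` is flat with one entry per routed slot while `routed_input` is two-dimensional, so `reshape(-1, 1)` lets each per-slot score broadcast across the model dimension, with the multiply done in float32 before casting back. The sizes here are hypothetical, not the real model dims.

```python
import torch

num_slots, dim = 6, 4  # num_slots = bs * slen * top_k in the real code
routed_input = torch.randn(num_slots, dim, dtype=torch.bfloat16)
top_scores = torch.rand(num_slots)

# upcast, broadcast the (num_slots, 1) scores across dim, cast back
scaled = (routed_input.to(torch.float32) * top_scores.reshape(-1, 1)).to(
    routed_input.dtype
)
assert scaled.shape == routed_input.shape
```

One apparent motivation for scaling the inputs rather than the expert outputs: `top_scores` never has to be padded and permuted alongside the activations for the grouped GEMM. Since the experts are nonlinear, though, this is not numerically identical to output-side scaling, which is presumably why the line is flagged as a temporary test.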
@@ -350,28 +351,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 num_local_tokens_per_expert,
                 self.experts.num_experts,
                 1,
-                token_indices[0] + self.experts.num_experts * ALIGN_SIZE_M,
+                token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M,
                 ALIGN_SIZE_M,
             )
-            token_indices = torch.vstack(
-                (token_indices, token_indices.new_zeros((dim)))
-            )
-            token_indices = token_indices[permuted_indices, :]
+
             routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim))))
+            input_shape = routed_input.shape
             routed_input = routed_input[permuted_indices, :]
         else:
             # NOTE: this would incur a synchronization between device and host
             num_local_tokens_per_expert = num_local_tokens_per_expert.tolist()
+            input_shape, permuted_indices = None, None

-        print("Num local tokens per expert: ", num_local_tokens_per_expert)
         # shape (bs*slen*top_k, dim)
         routed_output = self.experts(
             routed_input, num_local_tokens_per_expert
         )  # torch.Size([16384(bsz), 256])
-        print("Routed output shape: ", routed_output.shape)
-        routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
-            x.dtype
-        )
+
+        routed_output_unpermuted = routed_output.new_empty(input_shape)
+        routed_output_unpermuted[permuted_indices, :] = routed_output
+        routed_output = routed_output_unpermuted[:-1]
+
+        # TODO: use this instead of scaling routed_input by top_scores above; top_scores must first be padded to a multiple of 16
+        # routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
+        #     x.dtype
+        # )

         # shared expert
         if self.shared_expert is not None:
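The substantive change in this hunk: `token_indices` is no longer permuted alongside the activations. Instead, only `routed_input` is padded and permuted, its padded shape is recorded, and after the experts run the permutation is inverted by scatter-assigning rows into an empty buffer and dropping the padding row. A minimal round-trip sketch with a hand-written stand-in permutation; the alignment padding and index handling of `generate_permute_indices` itself are not reproduced.

```python
import torch

num_slots, dim = 6, 4
routed_input = torch.randn(num_slots, dim)

# stand-in for permuted_indices: an expert-grouped order over
# num_slots + 1 rows, where the extra row is the zero padding row
permuted_indices = torch.tensor([3, 0, 5, 1, 4, 2, 6])

# pad with one zero row, remember the padded shape, then permute
routed_input = torch.vstack((routed_input, routed_input.new_zeros(dim)))
input_shape = routed_input.shape
permuted = routed_input[permuted_indices, :]

out = permuted  # the experts would run here, preserving row order

# invert the permutation: scatter rows back, then drop the padding row
unpermuted = out.new_empty(input_shape)
unpermuted[permuted_indices, :] = out
restored = unpermuted[:-1]

assert torch.equal(restored, routed_input[:-1])
```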
@@ -381,10 +385,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         else:
             out = torch.zeros_like(x.reshape(bs * slen, dim))

-        print(
-            "Out shape: ", out.shape, out.grad.shape if out.grad is not None else None
-        )
-
         out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
         out = out.reshape(bs, slen, dim)
         return out
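And the final combine, untouched here apart from the removed debug print: `scatter_add` folds each routed slot's output back into its source token's row. A toy sketch, assuming `token_indices` was already expanded to `(num_slots, dim)` as in the earlier gather.

```python
import torch

num_tokens, top_k, dim = 3, 2, 4
num_slots = num_tokens * top_k

routed_output = torch.randn(num_slots, dim)
# each routed slot points at its source token row, expanded across dim
token_indices = torch.tensor([0, 1, 2, 0, 1, 2]).reshape(-1, 1).expand(-1, dim)

out = torch.zeros(num_tokens, dim)
out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
# row t of out is now the sum of the top_k expert outputs routed from token t
```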