@@ -201,17 +201,20 @@ def forward(
201
201
min = 0 ,
202
202
max = self .num_experts ,
203
203
)
204
+
205
+ # Reorder the token indices to match the order of the experts
204
206
# token_indices_experts_sorted shape (bs*slen*top_k,)
205
207
token_indices_experts_sorted = torch .argsort (
206
208
selected_experts_indices .view (- 1 ), stable = True
207
209
)
210
+
211
+ # reorder the scores to match the order of the token indices
208
212
top_scores = top_scores .view (- 1 )[token_indices_experts_sorted ]
209
213
token_indices_experts_sorted = token_indices_experts_sorted // self .top_k
210
214
211
215
top_scores = (
212
216
top_scores * self .route_sclaing_factor
213
217
) # must multiply the scaling factor
214
- print ("In TokenChoiceTopKRouter, top_scores shape: " , top_scores )
215
218
return top_scores , token_indices_experts_sorted , num_local_tokens_per_expert
216
219
217
220
def init_weights (self , init_std : float ):
@@ -292,7 +295,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
292
295
Returns:
293
296
out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
294
297
"""
295
- print ("In MoE input, x shape: " , x )
296
298
bs , slen , dim = x .shape
297
299
298
300
# top_scores and selected_indices shape (bs*slen*top_k,)
@@ -303,15 +305,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
303
305
num_local_tokens_per_expert ,
304
306
) = self .router (x .reshape (bs * slen , dim ), self .expert_bias )
305
307
306
- # print(
307
- # "In MoE, top_scores shape: ",
308
- # top_scores.shape,
309
- # "token_indices: ",
310
- # token_indices.shape,
311
- # "num_local_tokens: ",
312
- # num_local_tokens_per_expert.shape,
313
- # )
314
-
315
308
# will be used to update the expert bias for load balancing
316
309
self .tokens_per_expert += num_local_tokens_per_expert
317
310
@@ -324,12 +317,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
324
317
dim = 0 ,
325
318
index = token_indices ,
326
319
)
327
- print ("Routed input: " , routed_input )
328
-
329
- # TODO: remove this line, this is a temporary test
330
- routed_input = (routed_input .to (torch .float32 ) * top_scores .reshape (- 1 , 1 )).to (
331
- x .dtype
332
- )
333
320
334
321
if self .use_grouped_mm :
335
322
# NOTE: In order to use torch._grouped_mm, we need to make sure
@@ -361,30 +348,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
361
348
else :
362
349
# NOTE: this would incur a synchronization between device and host
363
350
num_local_tokens_per_expert = num_local_tokens_per_expert .tolist ()
364
- input_shape , permuted_indices = None , None
351
+ permuted_indices , input_shape = None , None
365
352
366
353
# shape (bs*slen*top_k, dim)
367
- routed_output = self .experts (
368
- routed_input , num_local_tokens_per_expert
369
- ) # torch.Size([16384(bsz), 256])
354
+ routed_output = self .experts (routed_input , num_local_tokens_per_expert )
370
355
371
- routed_output_unpermuted = routed_output .new_empty (input_shape )
372
- routed_output_unpermuted [permuted_indices , :] = routed_output
373
- routed_output = routed_output_unpermuted [:- 1 ]
356
+ if self .use_grouped_mm :
357
+ # NOTE: Reverse the permutation to get the original order as inputs
358
+ routed_output_unpermuted = routed_output .new_empty (input_shape )
359
+ routed_output_unpermuted [permuted_indices , :] = routed_output
360
+ routed_output = routed_output_unpermuted [:- 1 ] # remove padding
374
361
375
- # TODO: Use this line instead if routed_input*top_scores, need to pad top_scores to be multiple of 16
376
- # routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to(
377
- # x.dtype
378
- # )
362
+ routed_output = (routed_output .to (torch .float32 ) * top_scores .unsqueeze (- 1 )).to (
363
+ x .dtype
364
+ )
379
365
380
366
# shared expert
381
367
if self .shared_expert is not None :
382
368
out = self .shared_expert (x .reshape (1 , bs * slen , dim )).reshape (
383
369
bs * slen , dim
384
- ) # torch.Size([16384, 256]) None
370
+ )
385
371
else :
386
372
out = torch .zeros_like (x .reshape (bs * slen , dim ))
387
373
374
+ # Accumulate multiple expert results because each token can be routed to multiple experts
388
375
out = out .scatter_add (dim = 0 , index = token_indices , src = routed_output )
389
376
out = out .reshape (bs , slen , dim )
390
377
return out
0 commit comments