@@ -517,8 +517,8 @@ def fused_experts_with_all2all_buffer(
517
517
dtype = expert_idx_buffer_scatter .dtype ,
518
518
device = expert_idx_buffer_scatter .device ,
519
519
)
520
- non_pad_len = torch .sum (( expert_idx_buffer_scatter
521
- != global_num_experts ).to (torch .int32 ))
520
+ non_pad_len = torch .sum (
521
+ ( expert_idx_buffer_scatter != global_num_experts ).to (torch .int32 ))
522
522
hidden_states_pad_idx [expert_idx_buffer_scatter != global_num_experts ] = (
523
523
torch .arange (
524
524
non_pad_len ,
@@ -580,8 +580,8 @@ def fused_experts_with_all2all_buffer(
580
580
dist .all_to_all_single (hidden_states_gatter ,
581
581
hidden_states_scatter ,
582
582
group = ep_group .device_group )
583
- hidden_states_gatter = hidden_states_gatter [expert_idx_buffer_scatter !=
584
- global_num_experts ]
583
+ hidden_states_gatter = hidden_states_gatter [
584
+ expert_idx_buffer_scatter != global_num_experts ]
585
585
if hidden_states_gatter .shape [0 ] != row_idx_len :
586
586
hidden_states = torch .zeros (
587
587
(row_idx_len , hidden_states .shape [1 ]),
@@ -776,10 +776,9 @@ def fused_experts(
776
776
# This created multiple NaN and index_add_ will mix them up which harms accuracy
777
777
# remove this mask and filter after it being fixed
778
778
num_valid_tokens = mask .sum ()
779
- valid_token_mask = (torch .arange (0 ,
780
- sorted_token_indices .shape [0 ],
781
- device = device ).unsqueeze (1 )
782
- < num_valid_tokens )
779
+ valid_token_mask = (torch .arange (
780
+ 0 , sorted_token_indices .shape [0 ], device = device ).unsqueeze (1 ) <
781
+ num_valid_tokens )
783
782
valid_output = torch .where (
784
783
valid_token_mask , weighted_down_out ,
785
784
torch .zeros_like (weighted_down_out )).to (dtype )
0 commit comments