@@ -38,7 +38,7 @@ def __init__(self, model, **args):
        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
        self.global_expert_num = self.model.config.n_routed_experts

-
+
        # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 is supported here
        self.expert_weight_names = ["w13_weight", "w2_weight", "w13_weight_scale", "w13_weight_offset",
                                    "w2_weight_scale", "w2_weight_offset"]
@@ -62,6 +62,8 @@ def __init__(self, model, **args):
            self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
                self.model.get_log2phy_map(self.num_dense_layers + layer_idx)

+        self.all_topk_ids = []
+
    def init_buffer_tensor(self, num_buffer_tensor):
        for name in self.expert_weight_names:
            complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name
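This change splits workload tracking into a cheap per-step collection (`collect_topk_ids`) and a periodic reduction (`get_rank_expert_workload`), instead of accumulating a cumulative load table on every call. A minimal caller sketch of the intended pattern, assuming hypothetical names (`adaptor`, `run_forward`, `is_dummy`, `record_interval`) that are not part of this PR:

    for step in range(num_steps):
        run_forward(batch)                                  # hypothetical model step
        adaptor.collect_topk_ids(dummy_run=is_dummy(step))  # buffer ids; no-op on dummy runs
        if (step + 1) % record_interval == 0:
            # Reduce the buffered ids to a (num_moe_layers, global_expert_num)
            # load table; the buffer is cleared, so each window starts fresh.
            moe_load = adaptor.get_rank_expert_workload()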
@@ -82,39 +84,36 @@ def init_expert_param_per_layer(self):
                 for name in self.expert_weight_names]
            )

-    def get_rank_expert_workload(
-        self,
-        num_moe_layers: int,
-        dummy_run=False
-    ) -> torch.Tensor:
-
-        all_topk_ids = [self.model.get_topk_ids(i) for i in range(num_moe_layers)]
-        stacked = torch.stack(all_topk_ids, dim=0)
-        L, B, K = stacked.shape
-        N = B * K
-        device = stacked.device
-        G = self.global_expert_num
+    def collect_topk_ids(self, dummy_run=False):
+        if dummy_run:
+            return
+        self.all_topk_ids.append(self.model.get_all_topk_ids(self.num_moe_layers))

-        if not hasattr(self, "cum_moe_load") or self.cum_moe_load is None:
-            self.cum_moe_load = torch.zeros((L, G),
-                                            dtype=torch.int64,
-                                            device=device)
+    def get_rank_expert_workload(self) -> torch.Tensor:

-        if dummy_run:
-            return self.cum_moe_load
+        device = self.all_topk_ids[0][0].device
+        flat_list_per_layer = [[] for _ in range(self.num_moe_layers)]

-        ids1d = stacked.view(-1).to(torch.int64)
+        for period_data in self.all_topk_ids:
+            for l in range(self.num_moe_layers):
+                t = period_data[l]
+                flat_list_per_layer[l].append(t.reshape(-1))

-        row_idx = torch.arange(L, device=device).repeat_interleave(N)
+        index_2d = torch.nn.utils.rnn.pad_sequence(
+            [torch.cat(flat_list_per_layer[l]) for l in range(self.num_moe_layers)],
+            batch_first=True, padding_value=-1
+        ).to(device)

-        combined = row_idx * G + ids1d
+        mask = index_2d != -1
+        index_2d = index_2d.masked_select(mask).reshape(self.num_moe_layers, -1)
+        src_2d = torch.ones_like(index_2d, dtype=torch.int64)

-        counts = torch.bincount(combined, minlength=L * G)
-        workload = counts.view(L, G)
+        moe_load = torch.zeros((self.num_moe_layers, self.global_expert_num),
+                               dtype=torch.int64, device=device)
+        moe_load.scatter_add_(dim=1, index=index_2d, src=src_2d)

-        self.cum_moe_load.add_(workload)
-
-        return self.cum_moe_load
+        self.all_topk_ids = []
+        return moe_load

    def get_init_expert_map(self, num_moe_layers):
        expert_map = self.model.get_all_expert_map(num_moe_layers)
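For reference, here is a self-contained toy run of the new counting scheme; the two-layer, four-expert shapes and values are invented for illustration, and only the pad_sequence → masked_select → scatter_add_ pipeline mirrors the code above:

    import torch

    num_moe_layers, global_expert_num = 2, 4

    # Two collection periods, each holding one tensor of routed top-k expert
    # ids per MoE layer (what collect_topk_ids buffers per forward pass).
    all_topk_ids = [
        [torch.tensor([0, 1, 1]), torch.tensor([3, 3, 2])],  # period 1
        [torch.tensor([1, 0]), torch.tensor([2, 2])],        # period 2
    ]

    flat_list_per_layer = [[] for _ in range(num_moe_layers)]
    for period_data in all_topk_ids:
        for l in range(num_moe_layers):
            flat_list_per_layer[l].append(period_data[l].reshape(-1))

    # Pad layers to a common length with -1, then strip the padding again;
    # this round trip is a no-op when every layer routed the same token count.
    index_2d = torch.nn.utils.rnn.pad_sequence(
        [torch.cat(flat_list_per_layer[l]) for l in range(num_moe_layers)],
        batch_first=True, padding_value=-1)
    index_2d = index_2d.masked_select(index_2d != -1).reshape(num_moe_layers, -1)

    # Each routed id adds 1 to its expert's bucket within its layer's row.
    moe_load = torch.zeros((num_moe_layers, global_expert_num), dtype=torch.int64)
    moe_load.scatter_add_(dim=1, index=index_2d, src=torch.ones_like(index_2d))

    print(moe_load)  # tensor([[2, 3, 0, 0],
                     #         [0, 0, 3, 2]])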