
Commit 0897ccc

Merge pull request #101 from raindaywhu/br_main_into_eplb_wjh
optimize calculate moe load
2 parents fc88c4b + 9c329ed

File tree: 3 files changed, +38 −32 lines

  vllm_ascend/eplb/adaptor/vllm_adaptor.py
  vllm_ascend/eplb/eplb_updator.py
  vllm_ascend/models/deepseek_v2.py

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 26 additions & 27 deletions
@@ -38,7 +38,7 @@ def __init__(self, model, **args):
         self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
         self.global_expert_num = self.model.config.n_routed_experts
 
-
+
         # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 is supported here
         self.expert_weight_names = ["w13_weight", "w2_weight", "w13_weight_scale", "w13_weight_offset",
                                     "w2_weight_scale", "w2_weight_offset"]
@@ -62,6 +62,8 @@ def __init__(self, model, **args):
             self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] =\
                 self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
 
+        self.all_topk_ids = []
+
     def init_buffer_tensor(self, num_buffer_tensor):
         for name in self.expert_weight_names:
             complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name
@@ -82,39 +84,36 @@ def init_expert_param_per_layer(self):
                 for name in self.expert_weight_names]
             )
 
-    def get_rank_expert_workload(
-        self,
-        num_moe_layers: int,
-        dummy_run = False
-    ) -> torch.Tensor:
-
-        all_topk_ids = [self.model.get_topk_ids(i) for i in range(num_moe_layers)]
-        stacked = torch.stack(all_topk_ids, dim=0)
-        L, B, K = stacked.shape
-        N = B * K
-        device = stacked.device
-        G = self.global_expert_num
+    def collect_topk_ids(self, dummy_run=False):
+        if dummy_run:
+            return
+        self.all_topk_ids.append(self.model.get_all_topk_ids(self.num_moe_layers))
 
-        if not hasattr(self, "cum_moe_load") or self.cum_moe_load is None:
-            self.cum_moe_load = torch.zeros((L, G),
-                                            dtype=torch.int64,
-                                            device=device)
+    def get_rank_expert_workload(self) -> torch.Tensor:
 
-        if dummy_run:
-            return self.cum_moe_load
+        device = self.all_topk_ids[0][0].device
+        flat_list_per_layer = [[] for _ in range(self.num_moe_layers)]
 
-        ids1d = stacked.view(-1).to(torch.int64)
+        for period_data in self.all_topk_ids:
+            for l in range(self.num_moe_layers):
+                t = period_data[l]
+                flat_list_per_layer[l].append(t.reshape(-1))
 
-        row_idx = torch.arange(L, device=device).repeat_interleave(N)
+        index_2d = torch.nn.utils.rnn.pad_sequence(
+            [torch.cat(flat_list_per_layer[l]) for l in range(self.num_moe_layers)],
+            batch_first=True, padding_value=-1
+        ).to(device)
 
-        combined = row_idx * G + ids1d
+        mask = index_2d != -1
+        index_2d = index_2d.masked_select(mask).reshape(self.num_moe_layers, -1)
+        src_2d = torch.ones_like(index_2d, dtype=torch.int64)
 
-        counts = torch.bincount(combined, minlength=L * G)
-        workload = counts.view(L, G)
+        moe_load = torch.zeros((self.num_moe_layers), self.global_expert_num,
+                               dtype=torch.int64, device=device)
+        moe_load.scatter_add_(dim=1, index=index_2d, src=src_2d)
 
-        self.cum_moe_load.add_(workload)
-
-        return self.cum_moe_load
+        self.all_topk_ids = []
+        return moe_load
 
     def get_init_expert_map(self, num_moe_layers):
         expert_map = self.model.get_all_expert_map(num_moe_layers)
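For intuition, here is a minimal standalone sketch of the new counting scheme on toy data (shapes and values are hypothetical): each collected topk_ids tensor is flattened per layer, the per-layer streams are padded into one rectangular index tensor, and scatter_add_ turns every token-to-expert assignment into a count.

import torch
from torch.nn.utils.rnn import pad_sequence

# Toy setup (hypothetical): 2 MoE layers, 4 global experts.
num_moe_layers, global_expert_num = 2, 4

# Two collection periods; each holds one (batch, top_k) topk_ids tensor per layer.
period1 = [torch.tensor([[0, 1], [2, 1]]), torch.tensor([[3, 3], [0, 2]])]
period2 = [torch.tensor([[1, 1]]), torch.tensor([[2, 0]])]
all_topk_ids = [period1, period2]

# Flatten each layer's ids across all periods, as in get_rank_expert_workload.
flat = [torch.cat([p[l].reshape(-1) for p in all_topk_ids])
        for l in range(num_moe_layers)]

# Pad with -1 into a rectangle, then drop the padding again. With equal-length
# layers (the usual case) the pad is a no-op and this just stacks the rows.
index_2d = pad_sequence(flat, batch_first=True, padding_value=-1)
mask = index_2d != -1
index_2d = index_2d.masked_select(mask).reshape(num_moe_layers, -1)

# One count per selection, scatter-added into a (layers, experts) table.
src_2d = torch.ones_like(index_2d, dtype=torch.int64)
moe_load = torch.zeros(num_moe_layers, global_expert_num, dtype=torch.int64)
moe_load.scatter_add_(dim=1, index=index_2d, src=src_2d)

print(moe_load)
# tensor([[1, 4, 1, 0],
#         [2, 0, 2, 2]])

Unlike the replaced bincount version, which kept a persistent cum_moe_load buffer, this variant accumulates raw ids across steps and resets all_topk_ids after each gather, so each returned workload covers exactly one collection window. (The extra parentheses in torch.zeros((self.num_moe_layers), self.global_expert_num, ...) in the committed code are harmless; Python reads them as a plain int argument.)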

vllm_ascend/eplb/eplb_updator.py

Lines changed: 5 additions & 5 deletions
@@ -57,7 +57,7 @@ def init_eplb(self, expert_map_path):
         self.cur_iterations: torch.int64 = 0
 
         self.wait_worker_iterations: torch.int64 = 0
-        self.num_wait_worker_iterations: torch.int64 = 10
+        self.num_wait_worker_iterations: torch.int64 = 20
 
         self.planner_block_queue = Queue()
         self.block_update_queue = Queue(maxsize=1)
@@ -131,11 +131,11 @@ def forward_before(self):
         self.eplb_loader.asyn_expert_weight_transfer(self.reqs)
 
     def forward_end(self,dummy_run=False):
-        self.adaptor.get_rank_expert_workload(self.num_moe_layers,dummy_run)
+        self.adaptor.collect_topk_ids(dummy_run)
         if not self.update_in_flight:
             load_gather_iteration, update_iteration = self.get_update_iteration()
             if load_gather_iteration:
-                moe_load = self.compute_and_set_moe_load(dummy_run)
+                moe_load = self.compute_and_set_moe_load()
             if update_iteration:
                 self.wakeup_eplb_worker()
                 self.update_in_flight = True
@@ -148,7 +148,7 @@ def forward_end(self,dummy_run=False):
         self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)
 
     def compute_and_set_moe_load(self,dummy_run=False):
-        local_load = self.adaptor.get_rank_expert_workload(self.num_moe_layers,dummy_run)
+        local_load = self.adaptor.get_rank_expert_workload()
         self._gather_buffer = None
         if dist.is_initialized():
             self.world_size = dist.get_world_size()
@@ -173,7 +173,7 @@ def compute_and_set_moe_load(self,dummy_run=False):
     def warm_up_eplb(self):
 
         self.get_init_expert_map()
-
+        self.adaptor.collect_topk_ids(dummy_run=False)
         self.compute_and_set_moe_load()
 
         src_tensor = torch.empty((1,), device=self.device)
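After this change, forward_end does only a cheap append (collect_topk_ids) on every step, while the counting and the cross-rank gather run only on load-gather iterations, whose period also doubles from 10 to 20. The gather itself lies mostly outside these hunks; below is a hedged sketch of the usual all-gather pattern such code follows (the helper name and return shape are assumptions, not this file's actual implementation):

import torch
import torch.distributed as dist

def gather_moe_load(local_load: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch: collect each rank's (layers, experts) workload
    # so every rank can reconstruct the same global view.
    if not dist.is_initialized():
        return local_load.unsqueeze(0)
    world_size = dist.get_world_size()
    gather_buffer = [torch.empty_like(local_load) for _ in range(world_size)]
    dist.all_gather(gather_buffer, local_load)
    return torch.stack(gather_buffer, dim=0)  # (world_size, layers, experts)

Sharing the per-rank counts this way would let each rank feed an identical global load to the EPLB planner before experts are rebalanced.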

vllm_ascend/models/deepseek_v2.py

Lines changed: 7 additions & 0 deletions
@@ -772,5 +772,12 @@ def get_all_expert_map(self,num_moe_layers):
     def get_topk_ids(self,layer_id):
         return self.model.layers[layer_id+3].mlp.experts.topk_ids
 
+    def get_all_topk_ids(self,num_moe_layers):
+        all_topk_id = []
+        for layer_id in range(num_moe_layers):
+            load_tensor = self.get_topk_ids(layer_id)
+            all_topk_id.append(load_tensor)
+        return all_topk_id
+
 class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
     pass
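A hedged usage sketch of the new accessor; the layer_id+3 offset presumably skips the model's leading dense layers, the same prefix the adaptor tracks as num_dense_layers (the TODO there notes only deepseek v3 w8a8 is supported):

# Assumed wiring, matching the adaptor change above: one topk_ids tensor
# per MoE layer for the current step, returned as a plain list.
topk_per_layer = model.get_all_topk_ids(num_moe_layers)
assert len(topk_per_layer) == num_moe_layers
# collect_topk_ids(dummy_run=False) appends exactly this list once per step.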
