
Commit da49def

Author: lt
Commit message: merge from remote main
2 parents: 45766f6 + 96fe998

File tree

5 files changed (+80, -96 lines)


vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 39 additions & 59 deletions

@@ -38,7 +38,7 @@ def __init__(self, model, **args):
         self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
         self.global_expert_num = self.model.config.n_routed_experts

-
+
         # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 is supported here
         self.expert_weight_names = ["w13_weight", "w2_weight", "w13_weight_scale", "w13_weight_offset",
                                     "w2_weight_scale", "w2_weight_offset"]

@@ -62,6 +62,8 @@ def __init__(self, model, **args):
             self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] =\
                 self.model.get_log2phy_map(self.num_dense_layers + layer_idx)

+        self.all_topk_ids = []
+
     def init_buffer_tensor(self, num_buffer_tensor):
         for name in self.expert_weight_names:
             complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name

@@ -82,39 +84,43 @@ def init_expert_param_per_layer(self):
                  for name in self.expert_weight_names]
             )

-    def get_rank_expert_workload(
-        self,
-        num_moe_layers: int,
-        dummy_run = False
-    ) -> torch.Tensor:
-
-        all_topk_ids = [self.model.get_topk_ids(i) for i in range(num_moe_layers)]
-        stacked = torch.stack(all_topk_ids, dim=0)
-        L, B, K = stacked.shape
-        N = B * K
-        device = stacked.device
-        G = self.global_expert_num
-
-        if not hasattr(self, "cum_moe_load") or self.cum_moe_load is None:
-            self.cum_moe_load = torch.zeros((L, G),
-                                            dtype=torch.int64,
-                                            device=device)
-
+    def collect_topk_ids(self, dummy_run=False):
         if dummy_run:
-            return self.cum_moe_load
-
-        ids1d = stacked.view(-1).to(torch.int64)
-
-        row_idx = torch.arange(L, device=device).repeat_interleave(N)
-
-        combined = row_idx * G + ids1d
-
-        counts = torch.bincount(combined, minlength=L * G)
-        workload = counts.view(L, G)
-
-        self.cum_moe_load.add_(workload)
-
-        return self.cum_moe_load
+            return
+        self.all_topk_ids.append(self.model.get_all_topk_ids(self.num_moe_layers))
+
+    def get_rank_expert_workload(self) -> torch.Tensor:
+        device = self.all_topk_ids[0][0].device
+        if not hasattr(self, "moe_load"):
+            self.moe_load = torch.zeros(
+                (self.num_moe_layers), self.global_expert_num,
+                dtype=torch.int64,
+                device=self.all_topk_ids[0][0].device,
+            )
+        else:
+            self.moe_load.zero_()
+        flat_list_per_layer = [[] for _ in range(self.num_moe_layers)]
+
+        for period_data in self.all_topk_ids:
+            for l in range(self.num_moe_layers):
+                t = period_data[l]
+                flat_list_per_layer[l].append(t.reshape(-1))
+
+        index_2d = torch.nn.utils.rnn.pad_sequence(
+            [torch.cat(flat_list_per_layer[l]) for l in range(self.num_moe_layers)],
+            batch_first=True, padding_value=-1
+        ).to(device)
+
+        mask = index_2d != -1
+        index_2d = index_2d.masked_select(mask).reshape(self.num_moe_layers, -1)
+        src_2d = torch.ones_like(index_2d, dtype=torch.int64)
+
+        self.moe_load.scatter_add_(dim=1, index=index_2d, src=src_2d)
+
+        if self.all_topk_ids:
+            self.all_topk_ids[:] = self.all_topk_ids[-1:]
+        return self.moe_load

     def get_init_expert_map(self, num_moe_layers):
         expert_map = self.model.get_all_expert_map(num_moe_layers)

@@ -136,32 +142,6 @@ def get_init_expert_map(self, num_moe_layers):
         return all_expert_maps

-    def local2global(self,
-                     placement_local: torch.Tensor
-                     ) -> torch.Tensor:
-
-        L, G, E_local = placement_local.shape
-        device = placement_local.device
-
-        max_id = torch.max(placement_local)
-        E_global = (max_id + 1).item() if max_id >= 0 else 0
-
-        if E_global == 0:
-            return torch.empty((L, G, 0), dtype=torch.long, device=device)
-
-        placement_global = torch.full((L, G, E_global),
-                                      fill_value=-1,
-                                      dtype=torch.long,
-                                      device=device)
-
-        valid = placement_local >= 0
-        l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
-        gid_idx = placement_local[l_idx, g_idx, slot_idx]
-
-        placement_global[l_idx, g_idx, gid_idx] = slot_idx
-
-        return placement_global
-
     def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path):

         try:
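Note for readers of this diff: the new get_rank_expert_workload counts expert hits by concatenating the collected top-k ids per layer, padding them into a rectangular tensor, masking the padding, and accumulating with scatter_add_. Below is a minimal, self-contained sketch of that counting pattern; the sizes and id values are invented for illustration, only the pad_sequence / masked_select / scatter_add_ structure mirrors the patch.

# Illustrative sketch (not the adaptor code): per-layer hit counts from top-k routing ids.
import torch

num_moe_layers, global_expert_num = 2, 8
# one collection period: a flattened [num_tokens * top_k] id tensor per MoE layer
topk_ids_per_layer = [torch.tensor([0, 3, 3, 7]), torch.tensor([1, 1, 2, 5])]

index_2d = torch.nn.utils.rnn.pad_sequence(topk_ids_per_layer,
                                            batch_first=True, padding_value=-1)
mask = index_2d != -1
# every layer routes the same tokens with the same top_k, so the masked result stays rectangular
index_2d = index_2d.masked_select(mask).reshape(num_moe_layers, -1)

moe_load = torch.zeros(num_moe_layers, global_expert_num, dtype=torch.int64)
moe_load.scatter_add_(dim=1, index=index_2d, src=torch.ones_like(index_2d))
print(moe_load)  # row l holds the per-expert hit counts for MoE layer l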

vllm_ascend/eplb/core/worker/eplb_worker.py

Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ def do_update(self):

         # use the load info to get the updated expert placement table
         load_info, old_placement = self.global2local(load_info, self.old_expert_maps, self.num_local_experts)
+        self.shared_dict["load_info"] = load_info
         changed, priority, new_placement = self.calculate_rebalance_experts(load_info, old_placement)

         if not torch.is_tensor(new_placement):
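The single added line is what makes the per-rank load_info visible outside the worker: it is written into the Manager-backed shared dictionary that the updator later reads back in get_expert_load(). A minimal sketch of that producer/consumer hand-off, assuming the "load_info" key from the patch and placeholder shapes:

# Sketch: share a load snapshot between processes via a multiprocessing Manager dict.
from multiprocessing import Manager

import torch

if __name__ == "__main__":
    shared_dict = Manager().dict({"expert_maps": None, "moe_load": None, "load_info": None})

    # producer side (worker): publish a [num_layers, world_size, local_num_experts] snapshot
    shared_dict["load_info"] = torch.zeros(2, 4, 16, dtype=torch.int64)

    # consumer side (updator): the proxy hands back a pickled copy of the stored tensor
    load_info = shared_dict["load_info"]
    print(load_info.shape)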

vllm_ascend/eplb/eplb_updator.py

Lines changed: 32 additions & 36 deletions

@@ -14,14 +14,17 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
+
 import torch
+from typing import Dict, List
 import torch.distributed as dist
 import vllm.envs as envs
 from multiprocessing import Queue, Manager

 from vllm.logger import logger
 from vllm_ascend.eplb.core.worker.eplb_worker import EplbProcess
 from vllm_ascend.eplb.core.loader.device_transfer_loader import D2DExpertWeightLoader
+from vllm_ascend.eplb.tool.eplb_utils import ExpertMapUtils

 class EplbUpdator:

@@ -32,6 +35,7 @@ def set_adaptor(self, adaptor):
         self.adaptor = adaptor
         self.eplb_loader = D2DExpertWeightLoader(eplb_adaptor=self.adaptor)
         self.num_moe_layers = self.adaptor.num_moe_layers
+        self.global_expert_num = self.adaptor.global_expert_num

     def init_eplb(self, expert_map_path):
         self.num_expert_load_gather = 10

@@ -57,7 +61,7 @@ def init_eplb(self, expert_map_path):
         self.cur_iterations: torch.int64 = 0

         self.wait_worker_iterations: torch.int64 = 0
-        self.num_wait_worker_iterations: torch.int64 = 10
+        self.num_wait_worker_iterations: torch.int64 = 20

         self.planner_block_queue = Queue()
         self.block_update_queue = Queue(maxsize=1)

@@ -69,7 +73,9 @@
             # hotness / load info [num_layers, world_size, num_experts]
             "moe_load": None,
             # all expert maps [num_layers, world_size, num_experts]
-            "expert_maps": None
+            "expert_maps": None,
+            # hotness / load info [num_layers, world_size, local_num_experts]
+            "load_info": None,
         })

         self.eplb = EplbProcess(

@@ -125,30 +131,31 @@ def forward_before(self):
             self.weight_update_counter = 0
             self.update_in_flight = False
             self.update_info_all = []
-
         # set asynchronous stream for d2d expert weight update
         self.reqs = []
         self.eplb_loader.asyn_expert_weight_transfer(self.reqs)

+
     def forward_end(self,dummy_run=False):
-        self.adaptor.get_rank_expert_workload(self.num_moe_layers,dummy_run)
-        if not self.update_in_flight:
-            load_gather_iteration, update_iteration = self.get_update_iteration()
-            if load_gather_iteration:
-                moe_load = self.compute_and_set_moe_load(dummy_run)
-            if update_iteration:
-                self.wakeup_eplb_worker()
-                self.update_in_flight = True
-                self.wait_worker_iterations = 0
-                self.weight_loading = False
-
-        if self.update_in_flight:
-            self.wait_worker_iterations = self.wait_worker_iterations + 1
-
-        self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)
+        self.adaptor.collect_topk_ids(dummy_run)
+        if not self.update_in_flight:
+            load_gather_iteration, update_iteration = self.get_update_iteration()
+            if load_gather_iteration:
+                moe_load = self.compute_and_set_moe_load()
+            if update_iteration:
+                self.wakeup_eplb_worker()
+                self.update_in_flight = True
+                self.wait_worker_iterations = 0
+                self.weight_loading = False
+
+        if self.update_in_flight:
+            self.wait_worker_iterations = self.wait_worker_iterations + 1
+
+        self.eplb_loader.update_expert_map_and_weight(self.reqs, self.redundant_enable)

     def compute_and_set_moe_load(self,dummy_run=False):
-        local_load = self.adaptor.get_rank_expert_workload(self.num_moe_layers,dummy_run)
+        local_load = self.adaptor.get_rank_expert_workload()
+
         self._gather_buffer = None
         if dist.is_initialized():
             self.world_size = dist.get_world_size()

@@ -173,7 +180,7 @@ def warm_up_eplb(self):
     def warm_up_eplb(self):

         self.get_init_expert_map()
-
+        self.adaptor.collect_topk_ids(dummy_run=False)
         self.compute_and_set_moe_load()

         src_tensor = torch.empty((1,), device=self.device)

@@ -228,29 +235,18 @@ def unpack_update_batch(self, packed_update_info):
         ]
         return recovered

-    def get_expert_load(self) -> str:
-
-        # todo wjh: provide the real return value
-        # return self.shared_dict['moe_load']
-        # mock json_str
-        experts_load = ('{\"expert_load\":['
-                        '{\"ip\":\"141.xxx.xxx.181\",'
-                        '\"node_0\":'
-                        '{\"card_0\":'
-                        '[{\"layer_4\":{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
-                        '\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1},\"layer_5\":{\"expert_0\":3,\"'
-                        'expert_2\":1}}]}},{\"ip\":\"141.xxx.xxx.177\",\"node_0\":{\"card_0\":[{\"layer_4\":'
-                        '{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
-                        '\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1}}]}}]}')
-        return experts_load
+    def get_expert_load(self) -> torch.Tensor:
+        load_info = self.shared_dict["load_info"]  # Tensor [L, W, local_experts_num]
+        logger.info(f"lt -- load_info {load_info=}...")
+        return load_info
+

     def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
         logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations}...")
         self.num_expert_load_gather = num_expert_load_gather
         self.num_iterations = num_iterations
         logger.info(f" update {self.num_expert_load_gather=}, {self.num_iterations} success...")

-
     def shutdown(self):
         """
         Clean up the EPLB process.
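With this change get_expert_load() no longer returns the hard-coded mock JSON string but the raw load_info tensor from the shared dictionary, shaped [num_layers, world_size, local_experts_num] according to the comments in the patch. A hedged sketch of how a caller might summarize that tensor; the reductions below are illustrative, not project API:

# Illustrative post-processing of the tensor now returned by get_expert_load().
import torch

num_layers, world_size, local_experts = 2, 4, 16
load_info = torch.randint(0, 100, (num_layers, world_size, local_experts), dtype=torch.int64)

tokens_per_rank = load_info.sum(dim=-1)        # [num_layers, world_size] total hits per rank
busiest_local_slot = load_info.argmax(dim=-1)  # [num_layers, world_size] hottest local expert slot
imbalance = tokens_per_rank.max(dim=-1).values / tokens_per_rank.float().mean(dim=-1)
print(tokens_per_rank, busiest_local_slot, imbalance, sep="\n")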

vllm_ascend/models/deepseek_v2.py

Lines changed: 7 additions & 0 deletions

@@ -772,5 +772,12 @@ def get_all_expert_map(self,num_moe_layers):
     def get_topk_ids(self,layer_id):
         return self.model.layers[layer_id+3].mlp.experts.topk_ids

+    def get_all_topk_ids(self,num_moe_layers):
+        all_topk_id = []
+        for layer_id in range(num_moe_layers):
+            load_tensor = self.get_topk_ids(layer_id)
+            all_topk_id.append(load_tensor)
+        return all_topk_id
+
 class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
     pass
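get_all_topk_ids simply bundles the per-layer topk_ids tensors into a list (the layer_id+3 offset in get_topk_ids appears to skip the model's leading dense layers), and the adaptor's collect_topk_ids appends one such list per forward step. A small sketch of the resulting layout; the sizes are invented, only the list-of-lists structure reflects the patch:

# Layout sketch: all_topk_ids is a list of steps, each step a list of per-layer
# [num_tokens, top_k] routing-id tensors, later flattened and counted per layer.
import torch

num_moe_layers, num_tokens, top_k, n_experts = 2, 4, 8, 256
step = [torch.randint(0, n_experts, (num_tokens, top_k)) for _ in range(num_moe_layers)]

all_topk_ids = []      # corresponds to the adaptor's self.all_topk_ids
all_topk_ids.append(step)

per_layer_flat = [torch.cat([s[l].reshape(-1) for s in all_topk_ids])
                  for l in range(num_moe_layers)]
print(len(all_topk_ids), per_layer_flat[0].shape)  # 1, torch.Size([32])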

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 1 addition & 1 deletion

@@ -678,7 +678,7 @@ def apply(
                 global_redundant_expert_num=global_redundant_expert_num,
                 shared_experts=shared_experts,
                 **kwargs), topk_ids
-        elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
+        elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w1_scale=layer.w13_weight_scale,
