Skip to content

Commit 83f2d51

Browse files
author
lt
committed
Merge branch 'br_main_into_eplb' of https://github.com/raindaywhu/vllm-ascend into br_main_into_eplb
2 parents f6830d4 + 1a8d238 commit 83f2d51

File tree

6 files changed

+77
-74
lines changed

6 files changed

+77
-74
lines changed

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,13 @@ def local2global(self,
152152
return placement_global
153153

154154
def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path):
155-
if os.path.exists(expert_map_path):
155+
156+
try:
156157
expert_map_tensor, layers_num, ranks_num = self._expert_file_to_tensor(expert_map_path)
157158
expert_map_all = self.local2global(expert_map_tensor)
158-
else:
159+
except (TypeError, FileNotFoundError, OSError):
159160
expert_map_all = self.determine_expert_map_all()
161+
160162
for layer_idx in range(num_moe_layers):
161163
self.expert_map_per_layer_cpu[layer_idx+3] = \
162164
expert_map_all[layer_idx][self.rank_id]

vllm_ascend/eplb/core/policy/dynamic_ep_v2.py

Lines changed: 19 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -307,61 +307,6 @@ def calculate_initial_imbalance(global_deployment, new_layer_workloads):
307307

308308
return layer_imbalance
309309

310-
def rebalance_experts(self, current_expert_table, expert_workload):
311-
312-
info = DynamicTable()
313-
info.workload_table = np.array(expert_workload)
314-
info.placement_table = np.array(current_expert_table)
315-
layer_num, num_npus, experts_per_npu = info.workload_table.shape
316-
expert_ids, counts = np.unique(info.placement_table[0], return_counts=True)
317-
num_redundancy_expert = self.get_redundant_num(num_npus, counts)
318-
num_original_expert = len(expert_ids)
319-
layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert)
320-
max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num)
321-
npu_heat_all_origin = sum(max_heat_per_layer_before)
322-
323-
# 计算负载均衡,部署冗余专家
324-
layer_num = layer_workloads.shape[0]
325-
expert_num = layer_workloads.shape[1]
326-
# 校验专家数量、卡数量、冗余专家数量不能超过卡数量
327-
if num_original_expert != expert_num:
328-
raise ValueError(f"原始专家数量 {num_original_expert} 必须等于 expert_num {expert_num}")
329-
330-
if num_npus <= 0:
331-
raise ValueError("NPUs 数量必须大于 0")
332-
333-
if num_npus < num_redundancy_expert:
334-
raise ValueError(f"NPUs 数量 {num_npus} 必须大于或等于冗余专家数量 {num_redundancy_expert}")
335-
336-
# 每个卡部署的专家数量 一个冗余专家
337-
global_deployment = [[[] for _ in range(num_npus)] for _ in range(layer_num)]
338-
# 遍历获得每一层的放置策略,考虑计算均衡
339-
max_heat_per_layer_after = np.zeros([layer_num])
340-
for layer in range(layer_num):
341-
# 获取当前层专家ID和对应负载,负载需要进行正则化处理, 每个卡加一个冗余专家
342-
weights = np.zeros((expert_num,), dtype='object')
343-
for expert_id, workload_weight in enumerate(layer_workloads[layer]):
344-
weights[expert_id] = (expert_id, workload_weight)
345-
346-
# 获取每一层全局计算均衡的放置策略
347-
result, layer_deployment = self.original_compute_balanced_pack_redundancy(
348-
weights, num_npus, num_redundancy_expert
349-
)
350-
global_deployment[layer] = layer_deployment
351-
max_heat_per_layer_after[layer] = max(result, key=lambda x: x['total_weight'])['total_weight']
352-
353-
# 获取层优先级
354-
layer_changed_ratio = []
355-
for layer_idx in range(layer_num):
356-
layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] / max_heat_per_layer_before[layer_idx])
357-
358-
per_layer_priority = np.argsort(layer_changed_ratio)
359-
npu_heat_all_after = sum(max_heat_per_layer_after)
360-
361-
change = 0
362-
363-
return change, per_layer_priority, np.array(global_deployment).tolist()
364-
365310
@staticmethod
366311
def compute_redundant_assignments(base_experts, num_redundant_experts, num_experts):
367312
"""
@@ -845,6 +790,25 @@ def rebalance_experts(self, current_expert_table, expert_workload):
845790
num_node, num_npus, False, ave_workload,
846791
0.05, num_redundancy_expert)
847792

793+
# To guarantee there is no expert movement inside an NPU
794+
start_physical_idx = 1 if num_redundancy_expert else 0
795+
for rank in range(num_npus):
796+
physical_expert = start_physical_idx
797+
while physical_expert in range(start_physical_idx, experts_per_npu):
798+
# skip the expert which is moved into this rank
799+
if global_deployment[layer][rank][physical_expert] not in current_expert_table[layer, rank, :]:
800+
physical_expert += 1
801+
continue
802+
803+
if global_deployment[layer][rank][physical_expert] != current_expert_table[layer][rank][physical_expert]:
804+
right_idx = np.where(current_expert_table[layer][rank] == global_deployment[layer][rank][physical_expert])[0][0]
805+
# exchange expert with the expert on the right physical index
806+
tempt = global_deployment[layer][rank][right_idx]
807+
global_deployment[layer][rank][right_idx] = global_deployment[layer][rank][physical_expert]
808+
global_deployment[layer][rank][physical_expert] = tempt
809+
else:
810+
physical_expert += 1
811+
848812
for device_id in range(num_npus):
849813
com_between_devices[device_id] = {int(key): int(value) for key, value in
850814
com_between_devices[device_id].items()}

vllm_ascend/eplb/core/policy/policy_factory.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,18 @@ class PolicyFactory:
1010
@staticmethod
1111
def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy:
1212
policy = {
13+
# Constraint applying Dynamic EPLB policy V2:
14+
# If there exists redundant expert:
15+
# only one redundant expert can be placed in one NPU and its physical expert index must be 0
16+
1317
# Applying bipartite d2d expert weight update composing
1418
0:MockLoadBalance, # MockLoadBalance
1519
1:DynamicEplb, # Dynamic EPLB policy
1620
2:DynamicEplbV2, # Dynamic EPLB policy V2
1721

1822
# Applying greedy d2d expert weight update composing
19-
4:MockLoadBalance, # MockLoadBalance
20-
5:DynamicEplb, # Dynamic EPLB policy
21-
6:DynamicEplbV2, # Dynamic EPLB policy
23+
3:MockLoadBalance, # MockLoadBalance
24+
4:DynamicEplb, # Dynamic EPLB policy
25+
5:DynamicEplbV2, # Dynamic EPLB policy V2
2226
}
2327
return policy.get(policy_type, MockLoadBalance)(config)

vllm_ascend/eplb/eplb_updator.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717
import torch
1818
import torch.distributed as dist
19+
import vllm.envs as envs
1920
from multiprocessing import Queue, Manager
2021

2122
from vllm.logger import logger
@@ -33,11 +34,17 @@ def set_adaptor(self, adaptor):
3334
self.num_moe_layers = self.adaptor.num_moe_layers
3435

3536
def init_eplb(self, expert_map_path):
36-
37+
self.num_expert_load_gather = 10
3738
self.redundant_enable = (expert_map_path != None)
3839
self.num_iterations: torch.int64 = 130
3940
self.expert_map_path = expert_map_path
4041

42+
try:
43+
if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING:
44+
self.num_expert_load_gather = self.num_iterations
45+
except Exception as e:
46+
self.num_expert_load_gather = self.num_iterations
47+
4148
self.weight_update_counter = 0
4249
self.expert_map_initialized = False
4350
self.update_in_flight = False
@@ -80,10 +87,9 @@ def init_eplb(self, expert_map_path):
8087

8188
def get_update_iteration(self):
8289
self.cur_iterations = self.cur_iterations + 1
83-
if not self.gate_eplb:
84-
return self.cur_iterations % self.num_iterations == 0
85-
else:
86-
return self.cur_iterations == self.num_iterations
90+
load_gather_iteration = self.cur_iterations % self.num_expert_load_gather == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
91+
update_iteration = self.cur_iterations % self.num_iterations == 0 if not self.gate_eplb else self.cur_iterations == self.num_iterations
92+
return load_gather_iteration, update_iteration
8793

8894
def get_init_expert_map(self):
8995
try:
@@ -125,12 +131,15 @@ def forward_before(self):
125131
self.eplb_loader.asyn_expert_weight_transfer(self.reqs)
126132

127133
def forward_end(self):
128-
if not self.update_in_flight and self.get_update_iteration():
129-
moe_load = self.compute_and_set_moe_load()
130-
self.wakeup_eplb_worker()
131-
self.update_in_flight = True
132-
self.wait_worker_iterations = 0
133-
self.weight_loading = False
134+
if not self.update_in_flight:
135+
load_gather_iteration, update_iteration = self.get_update_iteration()
136+
if load_gather_iteration:
137+
self.moe_load = self.compute_and_set_moe_load()
138+
if update_iteration:
139+
self.wakeup_eplb_worker()
140+
self.update_in_flight = True
141+
self.wait_worker_iterations = 0
142+
self.weight_loading = False
134143

135144
if self.update_in_flight:
136145
self.wait_worker_iterations = self.wait_worker_iterations + 1
@@ -220,9 +229,27 @@ def unpack_update_batch(self, packed_update_info):
220229
return recovered
221230

222231
def get_expert_load(self) -> str:
223-
"""todo 确认moe_load的值是什么类型"""
224-
# return '{"a":"b"}' # mock
225-
return self.shared_dict['moe_load']
232+
233+
# todo wjh 给到返回值
234+
# return self.shared_dict['moe_load']
235+
# mock json_str
236+
experts_load = ('{\"expert_load\":['
237+
'{\"ip\":\"141.xxx.xxx.181\",'
238+
'\"node_0\":'
239+
'{\"card_0\":'
240+
'[{\"layer_4\":{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
241+
'\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1},\"layer_5\":{\"expert_0\":3,\"'
242+
'expert_2\":1}}]}},{\"ip\":\"141.xxx.xxx.177\",\"node_0\":{\"card_0\":[{\"layer_4\":'
243+
'{\"expert_0\":3,\"expert_2\":1}},{\"layer_5\":{\"expert_0\":3,\"expert_2\":1}}],'
244+
'\"card_1\":[{\"layer_4\":{\"expert_1\":3,\"expert_3\":1}}]}}]}')
245+
return experts_load
246+
247+
def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
248+
logger.info(f" start update {self.num_expert_load_gather=}, {self.num_iterations=}...")
249+
self.num_expert_load_gather = num_expert_load_gather
250+
self.num_iterations = num_iterations
251+
logger.info(f" update {self.num_expert_load_gather=}, {self.num_iterations=} success...")
252+
226253

227254
def shutdown(self):
228255
"""

vllm_ascend/worker/model_runner_v1.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,6 +1589,9 @@ def profile_run(self) -> None:
15891589
def do_get_expert_load(self) -> str:
15901590
return self.eplb_updator.get_expert_load()
15911591

1592+
def do_update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
1593+
return self.eplb_updator.update_expert_load_statistical_period(num_expert_load_gather, num_iterations)
1594+
15921595
def eplb_warmup(self):
15931596
#EPLB
15941597
if self.dynamic_eplb and not self.is_eplb_warmuped:

vllm_ascend/worker/worker_v1.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,9 @@ def get_expert_load(self) -> str:
214214
moe_load = self.model_runner.do_get_expert_load()
215215
return moe_load
216216

217+
def update_expert_load_statistical_period(self, num_expert_load_gather: int, num_iterations: int):
218+
self.model_runner.do_update_expert_load_statistical_period(num_expert_load_gather, num_iterations)
219+
217220
def get_model(self) -> nn.Module:
218221
return self.model_runner.get_model()
219222

0 commit comments

Comments
 (0)