
Commit bfa07cf

Merge pull request #104 from raindaywhu/new_dev_main_cy
fix SwiftBalancer eplb algo
2 parents c57611c + 1f0b980 commit bfa07cf

File tree

1 file changed: +56 -49 lines changed
vllm_ascend/eplb/core/policy/dynamic_ep_v2.py

Lines changed: 56 additions & 49 deletions
@@ -383,7 +383,30 @@ def non_redundant_expert_information(origin_deployment, updated_weights, num_rad
         return device_assignments, device_weights, device_loads, device_counts
 
     @staticmethod
-    def distribute_redun_experts(device_assignments, device_weights, device_loads, device_counts, redundant_expert_list,
+    def recomputing_weight(layer_workloads, device_assignments, device_weights, device_loads):
+        # Count how many times each expert appears
+        num_all_experts = [0] * len(layer_workloads)
+        num_devices = len(device_assignments)
+        for device_id in range(num_devices):
+            num_expert_per_npu = len(device_assignments[device_id])
+            for idx in range(num_expert_per_npu):
+                num_all_experts[idx] += device_assignments[device_id][idx]
+
+        for device_id in range(num_devices):
+            num_expert_per_npu = len(device_weights[device_id])
+            total_weight = 0.0
+            for idx in range(num_expert_per_npu):
+                expert_id = device_assignments[device_id][idx]
+                if num_all_experts[expert_id] == 0:
+                    print("Error: Division by zero")
+                device_weights[device_id][idx] = layer_workloads[expert_id] / num_all_experts[expert_id]
+                total_weight += device_weights[device_id][idx]
+            device_loads[device_id] = total_weight
+
+        return device_weights, device_loads
+
+    @staticmethod
+    def distribute_redun_experts(self, layer_workloads, device_assignments, device_weights, device_loads, device_counts, redundant_expert_list,
                                  items_per_device, expert_form_device, num_experts):
 
         num_devices = len(device_assignments)
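The new recomputing_weight pass is meant to redistribute each logical expert's measured workload evenly across all of its deployed replicas and then refresh the per-device load totals. A minimal self-contained sketch of that redistribution, with illustrative names rather than the repository's API:

def split_workload_across_replicas(layer_workloads, device_assignments):
    # Count the deployed replicas of each logical expert.
    replica_counts = [0] * len(layer_workloads)
    for assignments in device_assignments:
        for expert_id in assignments:
            replica_counts[expert_id] += 1

    # Each replica carries an equal share of its expert's workload;
    # a device's load is the sum of the shares it hosts.
    device_weights = [
        [layer_workloads[e] / replica_counts[e] for e in assignments]
        for assignments in device_assignments
    ]
    device_loads = [sum(w) for w in device_weights]
    return device_weights, device_loads

# Expert 0 is replicated on both devices, so its workload of 90 splits 45/45:
weights, loads = split_workload_across_replicas([90, 30, 60], [[0, 1], [0, 2]])
print(weights)  # [[45.0, 30.0], [45.0, 60.0]]
print(loads)    # [75.0, 105.0]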
@@ -411,18 +434,16 @@ def distribute_redun_experts(device_assignments, device_weights, device_loads, d
                     communication_box_index = expert_form_device[expert_id]
                     com_between_devices[candidate][communication_box_index] = expert_id
         # In extreme cases a redundant expert may not get packed, leaving an empty slot in a box; fill it with an arbitrary expert (to be optimized)
+        flag = False
         for dev_id in range(num_devices):
             # Check the capacity limit
             if device_counts[dev_id] < items_per_device:
                 # Look for a suitable expert
                 for expert_id in range(num_experts):
                     if expert_id not in device_assignments[dev_id]:
-                        # Look up the corresponding weight
-                        weight = 0
-                        for i in range(num_devices):
-                            for j in range(len(device_assignments[i])):
-                                if expert_id == device_assignments[i][j]:
-                                    weight = device_weights[i][j]
+                        flag = True
+                        # Arbitrarily initialize a weight
+                        weight = 0.0
                         # The weights of the cards holding this expert change; to be fixed
                         device_assignments[dev_id].insert(0, expert_id)
                         device_weights[dev_id].insert(0, weight)
@@ -432,12 +453,14 @@ def distribute_redun_experts(device_assignments, device_weights, device_loads, d
                         communication_box_index = expert_form_device[expert_id]
                         com_between_devices[dev_id][communication_box_index] = expert_id
                         break
-        # TODO: regenerate the weights
+
+        if flag:
+            device_weights, device_loads = self.recomputing_weight(layer_workloads, device_assignments, device_weights, device_loads)
 
         return device_assignments, device_weights, device_loads, device_counts, com_between_devices
 
     @staticmethod
-    def redundancy_again(self, origin_weights, num_redundant_experts, origin_deployment, expert_form_device, num_node,
+    def redundancy_again(self, layer_workloads, origin_weights, num_redundant_experts, origin_deployment, expert_form_device, num_node,
                          is_node_redundant):
 
         # Number of experts on each card
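The fallback path no longer copies a replica's weight from another card; it inserts the missing expert with a placeholder weight of 0.0, raises a flag, and defers the fix-up to a single recomputing_weight pass once every box is full. A condensed sketch of that control flow, with an illustrative helper name:

def fill_empty_slots(device_assignments, device_counts, items_per_device, num_experts):
    # Pack any device that still has a free slot with an expert it does not
    # already host; weights are repaired once afterwards instead of per insert.
    needs_recompute = False
    for dev_id, assignments in enumerate(device_assignments):
        if device_counts[dev_id] < items_per_device:
            for expert_id in range(num_experts):
                if expert_id not in assignments:
                    assignments.insert(0, expert_id)  # placeholder weight 0.0 goes in alongside
                    device_counts[dev_id] += 1
                    needs_recompute = True
                    break
    return needs_recompute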
@@ -461,6 +484,8 @@ def redundancy_again(self, origin_weights, num_redundant_experts, origin_deploym
 
         # Distribute the newly computed redundant experts
         device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.distribute_redun_experts(
+            self,
+            layer_workloads,
             device_assignments,
             device_weights,
             device_loads,
@@ -554,6 +579,7 @@ def redundant_expert_deployment(self, layer_workloads, original_deployment, expe
 
             cur_device_assignments, cur_device_weights, cur_device_loads, cur_device_counts, cur_com_between_devices = self.redundancy_again(
                 self,
+                layer_workloads,
                 cur_node_weights,
                 per_node_redun_expert_num,
                 cur_original_deployment,
@@ -569,6 +595,7 @@ def redundant_expert_deployment(self, layer_workloads, original_deployment, expe
         else:
             device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.redundancy_again(
                 self,
+                layer_workloads,
                 weights,
                 redundancy_expert_num,
                 original_deployment,
@@ -583,7 +610,7 @@ def redundant_expert_deployment(self, layer_workloads, original_deployment, expe
 
     @staticmethod
     def two_device_exchange_experts(cur_device_result, exchange_device_result, cur_exchanged_expert_id,
-                                    next_exchanged_expert_id, ave_workload, increment, num_redundancy_expert):
+                                    next_exchanged_expert_id, ave_workload, increment, num_redundancy_expert, cur_org_placement, next_org_placement):
 
         cur_device_weight = cur_device_result['expert_weights']
         next_device_weight = exchange_device_result['expert_weights']
@@ -609,7 +636,8 @@ def two_device_exchange_experts(cur_device_result, exchange_device_result, cur_e
                     continue
                 # Exchanges must keep the experts within a card distinct
                 change_flag = True
-                if cur_device_expert_id[index] in next_device_expert_id or next_device_expert_id[next_index] in cur_device_expert_id:
+                if ((cur_device_expert_id[index] in next_device_expert_id or next_device_expert_id[next_index] in cur_device_expert_id) or
+                        (cur_org_placement[0] == next_device_expert_id[next_index] or next_org_placement[0] == cur_device_expert_id[index])):
                     change_flag = False
                 # The selected expert must not have taken part in an earlier exchange
                 if (cur_device_expert_id[index] not in cur_exchanged_expert_id) and (
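Read as a standalone predicate, the strengthened condition rejects a swap not only when it would duplicate an expert within a card, but also when the incoming expert matches the expert originally placed at physical slot 0 of the receiving card. A hedged sketch of that check (illustrative function; it assumes slot 0 holds the card's original non-redundant expert, which must not move):

def swap_allowed(cur_experts, next_experts, i, j, cur_org_placement, next_org_placement):
    # Swapping moves cur_experts[i] to the next card and next_experts[j] here.
    # Reject the swap if either card would end up hosting a duplicate expert...
    if cur_experts[i] in next_experts or next_experts[j] in cur_experts:
        return False
    # ...or if the incoming expert equals the expert originally pinned at
    # slot 0 of the receiving card.
    if cur_org_placement[0] == next_experts[j] or next_org_placement[0] == cur_experts[i]:
        return False
    return True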
@@ -627,8 +655,7 @@ def two_device_exchange_experts(cur_device_result, exchange_device_result, cur_e
 
     @staticmethod
     def expert_exchange_between_devices(self, ave_workload, increment, cur_layer_result, com_between_devices, num_redundancy_expert,
-                                        node_idx=0,
-                                        per_node_device_num=0, is_node_redundant=False):
+                                        org_placement_table, node_idx=0, per_node_device_num=0, is_node_redundant=False):
 
         if is_node_redundant:
             # Take out the information of the devices within the current node
@@ -677,7 +704,9 @@ def expert_exchange_between_devices(self, ave_workload, increment, cur_layer_res
                         next_exchanged_expert_id,
                         ave_workload,
                         increment,
-                        num_redundancy_expert)
+                        num_redundancy_expert,
+                        org_placement_table[max_weight_device_id],
+                        org_placement_table[min_weight_device_id])
 
             # Perform the exchange when qualifying experts exist
             if cur_exchange_index != -1:
@@ -700,7 +729,7 @@ def expert_exchange_between_devices(self, ave_workload, increment, cur_layer_res
 
     @staticmethod
     def exchange_experts(self, layer_result, layer_com_between_devices, num_nodes, device_num, is_node_redundant,
-                         ave_workload, increment, num_redundancy_expert):
+                         ave_workload, increment, num_redundancy_expert, org_placement_table):
 
         global_deployment = []
 
@@ -709,9 +738,9 @@ def exchange_experts(self, layer_result, layer_com_between_devices, num_nodes, d
             for node_idx in range(num_nodes):
                 self.expert_exchange_between_devices(self, ave_workload, increment, layer_result,
                                                      layer_com_between_devices, num_redundancy_expert,
-                                                     node_idx, per_node_device_num, is_node_redundant)
+                                                     org_placement_table, node_idx, per_node_device_num, is_node_redundant)
         else:
-            self.expert_exchange_between_devices(self, ave_workload, increment, layer_result, layer_com_between_devices, num_redundancy_expert)
+            self.expert_exchange_between_devices(self, ave_workload, increment, layer_result, layer_com_between_devices, num_redundancy_expert, org_placement_table)
 
         max_workload = 0
         for box in layer_result:
@@ -734,14 +763,15 @@ def count_elements(self, lst):
         return count
 
     def rebalance_experts(self, current_expert_table, expert_workload):
-
+        # Inputs: the current expert placement and its workload, each shaped (layer_num, num_npus, experts_per_npu)
         info = DynamicTable()
-        info.workload_table = np.array(expert_workload)
-        info.placement_table = np.array(current_expert_table)
+        info.workload_table = expert_workload.numpy()
+        info.placement_table = current_expert_table.numpy()
         layer_num, num_npus, experts_per_npu = info.workload_table.shape
         expert_ids, counts = np.unique(info.placement_table[0], return_counts=True)
         num_redundancy_expert = self.get_redundant_num(num_npus, counts)
         num_original_expert = len(expert_ids)
+        # Convert the workload info to 58 * 256 (layers x experts)
         layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert)
         max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num)
         npu_heat_all_origin = sum(max_heat_per_layer_before)
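With the switch from np.array(...) to .numpy(), rebalance_experts now expects torch tensors rather than arbitrary array-likes. A minimal sketch of the assumed calling convention (sizes illustrative; .numpy() requires CPU tensors, so device-resident tensors would need .cpu() first):

import torch

# Shapes follow the commit's comment (58 layers, 256 logical experts);
# num_npus and experts_per_npu here are illustrative, not from the source.
layer_num, num_npus, experts_per_npu = 58, 8, 33
current_expert_table = torch.randint(0, 256, (layer_num, num_npus, experts_per_npu))
expert_workload = torch.randint(0, 1000, (layer_num, num_npus, experts_per_npu))

# What the policy now does internally:
placement_table = current_expert_table.numpy()
workload_table = expert_workload.numpy()
print(workload_table.shape)  # (58, 8, 33)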
@@ -764,50 +794,31 @@ def rebalance_experts(self, current_expert_table, expert_workload):
         # Number of experts deployed per card, including one redundant expert
         global_deployment = [[[] for _ in range(num_npus)] for _ in range(layer_num)]
         # Measure the initial imbalance of the 58 layers after switching the dataset
-        layer_initial_imbalance = self.calculate_initial_imbalance(current_expert_table, layer_workloads)
+        layer_initial_imbalance = self.calculate_initial_imbalance(info.placement_table, layer_workloads)
         # Iterate over layers to derive each placement strategy, with compute balance in mind
         max_heat_per_layer_after = np.zeros([layer_num])
         sum_num = 0
         for layer in range(layer_num):
             # Leave the layer unchanged when its imbalance is below the threshold
             if layer_initial_imbalance[layer] < 1.1:
-                global_deployment[layer] = current_expert_table[layer]
+                global_deployment[layer] = info.placement_table[layer]
                 continue
 
             ave_workload = np.sum(layer_workloads[layer]) / num_npus
-            for device_id, device in enumerate(current_expert_table[layer]):
+            for device_id, device in enumerate(info.placement_table[layer]):
                 for index, expert_id in enumerate(device):
                     if index != 0:
                         expert_from_device[layer][expert_id] = device_id
 
             # Adjust the redundant experts
             result, max_workload, com_between_devices = self.redundant_expert_deployment(self, layer_workloads[layer],
-                                                                                         current_expert_table[layer],
+                                                                                         info.placement_table[layer],
                                                                                          expert_from_device[layer],
                                                                                          num_node, False)
             # Exchange experts
             global_deployment[layer], new_max_workload = self.exchange_experts(self, result, com_between_devices,
                                                                                num_node, num_npus, False, ave_workload,
-                                                                               0.05, num_redundancy_expert)
-
-            # To guarantee there is no expert movement inside a NPU
-            start_physical_idx = 1 if num_redundancy_expert else 0
-            for rank in range(num_npus):
-                physical_expert = start_physical_idx
-                while physical_expert in range(start_physical_idx, experts_per_npu):
-                    # skip the expert which is moved into this rank
-                    if global_deployment[layer][rank][physical_expert] not in current_expert_table[layer, rank, :]:
-                        physical_expert += 1
-                        continue
-
-                    if global_deployment[layer][rank][physical_expert] != current_expert_table[layer][rank][physical_expert]:
-                        right_idx = np.where(current_expert_table[layer][rank] == global_deployment[layer][rank][physical_expert])[0][0]
-                        # exchange expert with the expert on the right physical index
-                        tempt = global_deployment[layer][rank][right_idx]
-                        global_deployment[layer][rank][right_idx] = global_deployment[layer][rank][physical_expert]
-                        global_deployment[layer][rank][physical_expert] = tempt
-                    else:
-                        physical_expert += 1
+                                                                               0.05, num_redundancy_expert, info.placement_table[layer])
 
             for device_id in range(num_npus):
                 com_between_devices[device_id] = {int(key): int(value) for key, value in
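A layer is only rebalanced when calculate_initial_imbalance reports at least 1.1; below that, its current placement is kept as-is. One plausible definition consistent with that threshold is the hottest device's load relative to the mean, where 1.0 means perfectly balanced; this is a hypothetical stand-in, not necessarily the repository's exact formula:

import numpy as np

def initial_imbalance_per_layer(placement_table, layer_workloads):
    # Hypothetical stand-in for calculate_initial_imbalance: per layer,
    # the max per-device load divided by the mean per-device load.
    ratios = []
    for placement, workloads in zip(placement_table, layer_workloads):
        device_loads = np.array([sum(workloads[e] for e in device) for device in placement])
        ratios.append(float(device_loads.max() / device_loads.mean()))
    return ratios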
@@ -828,8 +839,4 @@ def rebalance_experts(self, current_expert_table, expert_workload):
         if npu_heat_all_after < 0.95 * npu_heat_all_origin:
             change = 1
 
-        return change, per_layer_priority, np.array(global_deployment).tolist()
-
-
-
+        return change, per_layer_priority, np.array(global_deployment).tolist()
