|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +""" |
| 3 | +Expert parallelism load balancer (EPLB) for vLLM. |
| 4 | +The rearrangement algorithm is adapted from |
| 5 | +[DeepSeek EPLB](https://github.com/deepseek-ai/eplb). |
| 6 | +""" |
| 7 | +from typing import Tuple |
| 8 | + |
| 9 | +import torch |
| 10 | + |
| 11 | + |
| 12 | +def balanced_packing(weight: torch.Tensor, |
| 13 | + num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]: |
| 14 | + """ |
| 15 | + Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs |
| 16 | + are as balanced as possible. |
| 17 | +
|
| 18 | + Parameters: |
| 19 | + weight: [X, n], the weight of each item |
| 20 | + num_packs: number of packs |
| 21 | + |
| 22 | + Returns: |
| 23 | + pack_index: [X, n], the pack index of each item |
| 24 | + rank_in_pack: [X, n], the rank of the item in the pack |
| 25 | + """ |
| 26 | + num_layers, num_groups = weight.shape |
| 27 | + assert num_groups % num_packs == 0 |
| 28 | + groups_per_pack = num_groups // num_packs |
| 29 | + |
| 30 | + if groups_per_pack == 1: |
| 31 | + pack_index = torch.arange(weight.size(-1), |
| 32 | + dtype=torch.int64, |
| 33 | + device=weight.device).expand(weight.shape) |
| 34 | + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) |
| 35 | + return pack_index, rank_in_pack |
| 36 | + |
| 37 | + indices = weight.float().sort(-1, descending=True).indices.cpu() |
| 38 | + pack_index = torch.full_like(weight, |
| 39 | + fill_value=-1, |
| 40 | + dtype=torch.int64, |
| 41 | + device='cpu') |
| 42 | + rank_in_pack = torch.full_like(pack_index, fill_value=-1) |
| 43 | + for i in range(num_layers): |
| 44 | + pack_weights = [0] * num_packs |
| 45 | + pack_items = [0] * num_packs |
| 46 | + for group in indices[i]: |
| 47 | + pack = min( |
| 48 | + (i |
| 49 | + for i in range(num_packs) if pack_items[i] < groups_per_pack), |
| 50 | + key=pack_weights.__getitem__) |
| 51 | + assert pack_items[pack] < groups_per_pack |
| 52 | + pack_index[i, group] = pack |
| 53 | + rank_in_pack[i, group] = pack_items[pack] |
| 54 | + pack_weights[pack] += weight[i, group] |
| 55 | + pack_items[pack] += 1 |
| 56 | + return pack_index, rank_in_pack |
| 57 | + |
| 58 | + |
| 59 | +def replicate_experts( |
| 60 | + weight: torch.Tensor, |
| 61 | + num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| 62 | + """ |
| 63 | + Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized. |
| 64 | +
|
| 65 | + Parameters: |
| 66 | + weight: [X, num_log] |
| 67 | + num_phy: total number of experts after replication |
| 68 | + |
| 69 | + Returns: |
| 70 | + phy2log: [X, num_phy], logical expert id of each physical expert |
| 71 | + rank: [X, num_phy], the replica rank |
| 72 | + logcnt: [X, num_log], number of replicas for each logical expert |
| 73 | + """ |
| 74 | + n, num_log = weight.shape |
| 75 | + num_redundant = num_phy - num_log |
| 76 | + assert num_redundant >= 0 |
| 77 | + device = weight.device |
| 78 | + phy2log = torch.arange(num_phy, dtype=torch.int64, |
| 79 | + device=device).repeat(n, 1) |
| 80 | + rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) |
| 81 | + logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) |
| 82 | + arangen = torch.arange(n, dtype=torch.int64, device=device) |
| 83 | + for i in range(num_log, num_phy): |
| 84 | + redundant_indices = (weight / logcnt).max(dim=-1).indices |
| 85 | + phy2log[:, i] = redundant_indices |
| 86 | + rank[:, i] = logcnt[arangen, redundant_indices] |
| 87 | + logcnt[arangen, redundant_indices] += 1 |
| 88 | + return phy2log, rank, logcnt |
| 89 | + |
| 90 | + |
| 91 | +def rebalance_experts_hierarchical(weight: torch.Tensor, |
| 92 | + num_physical_experts: int, num_groups: int, |
| 93 | + num_nodes: int, num_gpus: int): |
| 94 | + """ |
| 95 | + Parameters: |
| 96 | + weight: [num_moe_layers, num_logical_experts] |
| 97 | + num_physical_experts: number of physical experts after replication |
| 98 | + num_groups: number of expert groups |
| 99 | + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster |
| 100 | + num_gpus: number of GPUs, must be a multiple of `num_nodes` |
| 101 | +
|
| 102 | + Returns: |
| 103 | + physical_to_logical_map: [num_moe_layers, num_physical_experts] |
| 104 | + logical_to_physical_map: [num_moe_layers, num_logical_experts, X] |
| 105 | + logical_count: [num_moe_layers, num_logical_experts] |
| 106 | + """ |
| 107 | + num_layers, num_logical_experts = weight.shape |
| 108 | + assert num_logical_experts % num_groups == 0 |
| 109 | + group_size = num_logical_experts // num_groups |
| 110 | + assert num_groups % num_nodes == 0 |
| 111 | + groups_per_node = num_groups // num_nodes |
| 112 | + assert num_gpus % num_nodes == 0 |
| 113 | + assert num_physical_experts % num_gpus == 0 |
| 114 | + phy_experts_per_gpu = num_physical_experts // num_gpus |
| 115 | + |
| 116 | + def inverse(perm: torch.Tensor) -> torch.Tensor: |
| 117 | + inv = torch.empty_like(perm) |
| 118 | + inv.scatter_( |
| 119 | + 1, perm, |
| 120 | + torch.arange(perm.size(1), dtype=torch.int64, |
| 121 | + device=perm.device).expand(perm.shape)) |
| 122 | + return inv |
| 123 | + |
| 124 | + # Step 1: pack groups to nodes |
| 125 | + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) |
| 126 | + group_pack_index, group_rank_in_pack = balanced_packing( |
| 127 | + tokens_per_group, num_nodes) |
| 128 | + log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * |
| 129 | + group_size).unsqueeze(-1) + |
| 130 | + torch.arange(group_size, |
| 131 | + dtype=torch.int64, |
| 132 | + device=group_pack_index.device)).flatten(-2) |
| 133 | + mlog2log = inverse(log2mlog) |
| 134 | + |
| 135 | + # Step 2: construct redundant experts within nodes |
| 136 | + # [num_layers * num_nodes, num_logical_experts // num_nodes] |
| 137 | + tokens_per_mlog = weight.gather(-1, mlog2log).view( |
| 138 | + -1, num_logical_experts // num_nodes) |
| 139 | + phy2mlog, phyrank, mlogcnt = replicate_experts( |
| 140 | + tokens_per_mlog, num_physical_experts // num_nodes) |
| 141 | + |
| 142 | + # Step 3: pack physical_experts to GPUs |
| 143 | + # [num_layers * num_nodes, num_physical_experts // num_nodes] |
| 144 | + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) |
| 145 | + pack_index, rank_in_pack = balanced_packing(tokens_per_phy, |
| 146 | + num_gpus // num_nodes) |
| 147 | + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack |
| 148 | + pphy2phy = inverse(phy2pphy) |
| 149 | + |
| 150 | + pphy2mlog = phy2mlog.gather( |
| 151 | + -1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] |
| 152 | + pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange( |
| 153 | + 0, |
| 154 | + num_logical_experts, |
| 155 | + num_logical_experts // num_nodes, |
| 156 | + device=group_pack_index.device).view(1, -1, 1)).flatten(-2) |
| 157 | + pphy2log = mlog2log.gather(-1, pphy2mlog) |
| 158 | + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) |
| 159 | + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) |
| 160 | + return pphy2log, pphyrank, logcnt |
| 161 | + |
| 162 | + |
| 163 | +def rebalance_experts( |
| 164 | + weight: torch.Tensor, num_replicas: int, num_groups: int, |
| 165 | + num_nodes: int, |
| 166 | + num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| 167 | + """ |
| 168 | + Entry point for expert-parallelism load balancer. |
| 169 | +
|
| 170 | + Parameters: |
| 171 | + weight: [layers, num_logical_experts], the load statistics for all logical experts |
| 172 | + num_replicas: number of physical experts, must be a multiple of `num_gpus` |
| 173 | + num_groups: number of expert groups |
| 174 | + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster |
| 175 | + num_gpus: number of GPUs, must be a multiple of `num_nodes` |
| 176 | +
|
| 177 | + Returns: |
| 178 | + physical_to_logical_map: [layers, num_replicas], the expert index of each replica |
| 179 | + logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert |
| 180 | + expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert |
| 181 | + """ |
| 182 | + num_layers, num_logical_experts = weight.shape |
| 183 | + weight = weight.float().cpu() |
| 184 | + if num_groups % num_nodes == 0: |
| 185 | + # use hierarchical load-balance policy |
| 186 | + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( |
| 187 | + weight, num_replicas, num_groups, num_nodes, num_gpus) |
| 188 | + else: |
| 189 | + # use global load-balance policy |
| 190 | + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( |
| 191 | + weight, num_replicas, 1, 1, num_gpus) |
| 192 | + maxlogcnt = logcnt.max().item() |
| 193 | + log2phy: torch.Tensor = torch.full( |
| 194 | + (num_layers, num_logical_experts, maxlogcnt), |
| 195 | + -1, |
| 196 | + dtype=torch.int64, |
| 197 | + device=logcnt.device) |
| 198 | + log2phy.view(num_layers, -1).scatter_( |
| 199 | + -1, phy2log * maxlogcnt + phyrank, |
| 200 | + torch.arange(num_replicas, dtype=torch.int64, |
| 201 | + device=log2phy.device).expand(num_layers, -1)) |
| 202 | + return phy2log, log2phy, logcnt |
| 203 | + |
| 204 | + |
| 205 | +__all__ = ['rebalance_experts'] |
0 commit comments