|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +import pytest |
| 5 | +import torch |
| 6 | + |
| 7 | +from vllm.distributed.eplb.rebalance_algo import rebalance_experts |
| 8 | + |
| 9 | + |
| 10 | +def test_basic_rebalance(): |
| 11 | + """Test basic rebalancing functionality""" |
| 12 | + # Example from https://github.com/deepseek-ai/eplb |
| 13 | + weight = torch.tensor([ |
| 14 | + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], |
| 15 | + [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], |
| 16 | + ]) |
| 17 | + |
| 18 | + num_layers = weight.shape[0] |
| 19 | + num_replicas = 16 |
| 20 | + num_groups = 4 |
| 21 | + num_nodes = 2 |
| 22 | + num_gpus = 8 |
| 23 | + |
| 24 | + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, |
| 25 | + num_groups, num_nodes, |
| 26 | + num_gpus) |
| 27 | + |
| 28 | + # Verify output shapes |
| 29 | + assert phy2log.shape == ( |
| 30 | + 2, |
| 31 | + 16, |
| 32 | + ), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}" |
| 33 | + assert (log2phy.shape[0] == 2 |
| 34 | + ), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}" |
| 35 | + assert ( |
| 36 | + log2phy.shape[1] == 12 |
| 37 | + ), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}" |
| 38 | + assert logcnt.shape == ( |
| 39 | + 2, |
| 40 | + 12, |
| 41 | + ), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}" |
| 42 | + |
| 43 | + # Verify physical to logical expert mapping range is correct |
| 44 | + assert torch.all(phy2log >= 0) and torch.all( |
| 45 | + phy2log < 12), "Physical to logical mapping should be in range [0, 12)" |
| 46 | + |
| 47 | + # Verify expert count reasonableness |
| 48 | + assert torch.all( |
| 49 | + logcnt >= 1), "Each logical expert should have at least 1 replica" |
| 50 | + assert ( |
| 51 | + torch.sum(logcnt, dim=1).sum() == num_replicas * |
| 52 | + num_layers), f"Total replicas should be {num_replicas * num_layers}" |
| 53 | + |
| 54 | + # Verify expected output |
| 55 | + expected_phy2log = torch.tensor([ |
| 56 | + [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1], |
| 57 | + [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1], |
| 58 | + ]) |
| 59 | + assert torch.all(phy2log == expected_phy2log) |
| 60 | + |
| 61 | + expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], |
| 62 | + [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]]) |
| 63 | + assert torch.all(logcnt == expected_logcnt) |
| 64 | + |
| 65 | + |
| 66 | +def test_single_gpu_case(): |
| 67 | + """Test single GPU case""" |
| 68 | + weight = torch.tensor([[10, 20, 30, 40]]) |
| 69 | + num_replicas = 4 |
| 70 | + num_groups = 1 |
| 71 | + num_nodes = 1 |
| 72 | + num_gpus = 1 |
| 73 | + |
| 74 | + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, |
| 75 | + num_groups, num_nodes, |
| 76 | + num_gpus) |
| 77 | + |
| 78 | + # Verify shapes |
| 79 | + assert phy2log.shape == (1, 4) |
| 80 | + assert log2phy.shape[0] == 1 |
| 81 | + assert log2phy.shape[1] == 4 |
| 82 | + assert logcnt.shape == (1, 4) |
| 83 | + |
| 84 | + # Verify all logical experts are mapped |
| 85 | + assert set(phy2log[0].tolist()) == {0, 1, 2, 3} |
| 86 | + |
| 87 | + |
| 88 | +def test_equal_weights(): |
| 89 | + """Test case with equal weights""" |
| 90 | + weight = torch.tensor([[50, 50, 50, 50, 50, 50, 50, 50]]) |
| 91 | + num_replicas = 8 |
| 92 | + num_groups = 2 |
| 93 | + num_nodes = 2 |
| 94 | + num_gpus = 4 |
| 95 | + |
| 96 | + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, |
| 97 | + num_groups, num_nodes, |
| 98 | + num_gpus) |
| 99 | + |
| 100 | + # Verify shapes |
| 101 | + assert phy2log.shape == (1, 8) |
| 102 | + assert logcnt.shape == (1, 8) |
| 103 | + |
| 104 | + # With equal weights, each expert should have exactly one replica |
| 105 | + assert torch.all( |
| 106 | + logcnt == 1 |
| 107 | + ), "With equal weights and no replication, " \ |
| 108 | + "each expert should have exactly 1 replica" |
| 109 | + |
| 110 | + |
| 111 | +def test_extreme_weight_imbalance(): |
| 112 | + """Test extreme weight imbalance case""" |
| 113 | + weight = torch.tensor([[1000, 1, 1, 1, 1, 1, 1, 1]]) |
| 114 | + num_replicas = 12 |
| 115 | + num_groups = 2 |
| 116 | + num_nodes = 2 |
| 117 | + num_gpus = 4 |
| 118 | + |
| 119 | + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, |
| 120 | + num_groups, num_nodes, |
| 121 | + num_gpus) |
| 122 | + |
| 123 | + # Verify shapes |
| 124 | + assert phy2log.shape == (1, 12) |
| 125 | + assert logcnt.shape == (1, 8) |
| 126 | + |
| 127 | + # Expert with highest weight (index 0) should have more replicas |
| 128 | + assert ( |
| 129 | + logcnt[0, 0] |
| 130 | + > logcnt[0, 1]), "Expert with highest weight should have more replicas" |
| 131 | + |
| 132 | + |
| 133 | +def test_multiple_layers(): |
| 134 | + """Test multiple layers case""" |
| 135 | + weight = torch.tensor([ |
| 136 | + [10, 20, 30, 40, 50, 60], # First layer |
| 137 | + [60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern) |
| 138 | + [25, 25, 25, 25, 25, 25], # Third layer (equal weights) |
| 139 | + ]) |
| 140 | + num_replicas = 8 |
| 141 | + num_groups = 2 |
| 142 | + num_nodes = 2 |
| 143 | + num_gpus = 4 |
| 144 | + |
| 145 | + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, |
| 146 | + num_groups, num_nodes, |
| 147 | + num_gpus) |
| 148 | + |
| 149 | + # Verify shapes |
| 150 | + assert phy2log.shape == (3, 8) |
| 151 | + assert logcnt.shape == (3, 6) |
| 152 | + |
| 153 | + # Verify expert allocation is reasonable for each layer |
| 154 | + for layer in range(3): |
| 155 | + assert torch.all(phy2log[layer] >= 0) and torch.all( |
| 156 | + phy2log[layer] < 6 |
| 157 | + ), f"Layer {layer} physical to logical mapping" \ |
| 158 | + "should be in range [0, 6)" |
| 159 | + assert (torch.sum(logcnt[layer]) == num_replicas |
| 160 | + ), f"Layer {layer} total replicas should be {num_replicas}" |
| 161 | + |
| 162 | + |
| 163 | +def test_parameter_validation(): |
| 164 | + """Test parameter validation""" |
| 165 | + weight = torch.tensor([[10, 20, 30, 40]]) |
| 166 | + |
| 167 | + # Test non-divisible case - this should handle normally without throwing |
| 168 | + # errors because the function will fall back to global load balancing |
| 169 | + # strategy |
| 170 | + phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4) |
| 171 | + assert phy2log.shape == (1, 8) |
| 172 | + assert logcnt.shape == (1, 4) |
| 173 | + |
| 174 | + # Test cases that will actually cause errors: |
| 175 | + # num_physical_experts not divisible by num_gpus |
| 176 | + with pytest.raises(AssertionError): |
| 177 | + rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4 |
| 178 | + |
| 179 | + |
| 180 | +def test_small_scale_hierarchical(): |
| 181 | + """Test small-scale hierarchical load balancing""" |
| 182 | + weight = torch.tensor([ |
| 183 | + [100, 50, 200, 75, 150, 25, 300, 80], # 8 experts |
| 184 | + ]) |
| 185 | + num_replicas = 12 |
| 186 | + num_groups = 4 # 4 groups, 2 experts each |
| 187 | + num_nodes = 2 # 2 nodes |
| 188 | + num_gpus = 4 # 4 GPUs |
| 189 | + |
| 190 | + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, |
| 191 | + num_groups, num_nodes, |
| 192 | + num_gpus) |
| 193 | + |
| 194 | + # Verify basic constraints |
| 195 | + assert phy2log.shape == (1, 12) |
| 196 | + assert logcnt.shape == (1, 8) |
| 197 | + assert torch.sum(logcnt) == num_replicas |
| 198 | + assert torch.all(logcnt >= 1) |
| 199 | + |
| 200 | + # Expert with highest weight should have more replicas |
| 201 | + max_weight_expert = torch.argmax(weight[0]) |
| 202 | + assert (logcnt[0, max_weight_expert] |
| 203 | + >= 2), "Highest weight expert should have multiple replicas" |
| 204 | + |
| 205 | + |
| 206 | +def test_global_load_balance_fallback(): |
| 207 | + """Test global load balancing fallback case""" |
| 208 | + # When num_groups % num_nodes != 0, should fall back to global load |
| 209 | + # balancing |
| 210 | + weight = torch.tensor([[10, 20, 30, 40, 50, 60]]) |
| 211 | + num_replicas = 8 |
| 212 | + num_groups = 3 # Cannot be divided evenly by num_nodes=2 |
| 213 | + num_nodes = 2 |
| 214 | + num_gpus = 4 |
| 215 | + |
| 216 | + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, |
| 217 | + num_groups, num_nodes, |
| 218 | + num_gpus) |
| 219 | + |
| 220 | + # Should work normally, just using global load balancing strategy |
| 221 | + assert phy2log.shape == (1, 8) |
| 222 | + assert logcnt.shape == (1, 6) |
| 223 | + assert torch.sum(logcnt) == num_replicas |
| 224 | + |
| 225 | + |
| 226 | +@pytest.mark.parametrize("device", ["cpu", "cuda"]) |
| 227 | +def test_device_compatibility(device): |
| 228 | + """Test device compatibility""" |
| 229 | + if device == "cuda" and not torch.cuda.is_available(): |
| 230 | + pytest.skip("CUDA not available") |
| 231 | + |
| 232 | + weight = torch.tensor([[10, 20, 30, 40]], device=device) |
| 233 | + num_replicas = 6 |
| 234 | + num_groups = 2 |
| 235 | + num_nodes = 1 |
| 236 | + num_gpus = 2 |
| 237 | + |
| 238 | + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, |
| 239 | + num_groups, num_nodes, |
| 240 | + num_gpus) |
| 241 | + |
| 242 | + # Function will convert to CPU internally, but should handle different |
| 243 | + # device inputs normally |
| 244 | + assert phy2log.shape == (1, 6) |
| 245 | + assert logcnt.shape == (1, 4) |
| 246 | + |
| 247 | + |
| 248 | +def test_additional_cases(): |
| 249 | + """Test more edge cases and different parameter combinations""" |
| 250 | + |
| 251 | + # Test case 1: Large-scale distributed setup |
| 252 | + weight1 = torch.tensor( |
| 253 | + [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]) |
| 254 | + phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8) |
| 255 | + |
| 256 | + assert phy2log1.shape == (1, 24) |
| 257 | + assert logcnt1.shape == (1, 16) |
| 258 | + assert torch.sum(logcnt1) == 24 |
| 259 | + |
| 260 | + # Test case 2: Different weight distributions |
| 261 | + weight2 = torch.tensor([ |
| 262 | + [200, 150, 100, 50, 25, 12], # Decreasing weights |
| 263 | + [12, 25, 50, 100, 150, 200], # Increasing weights |
| 264 | + ]) |
| 265 | + phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2) |
| 266 | + |
| 267 | + assert phy2log2.shape == (2, 10) |
| 268 | + assert logcnt2.shape == (2, 6) |
| 269 | + |
| 270 | + # Verify high-weight experts have more replicas |
| 271 | + for layer in range(2): |
| 272 | + max_weight_idx = torch.argmax(weight2[layer]) |
| 273 | + assert logcnt2[layer, max_weight_idx] >= 2 |
| 274 | + |
| 275 | + |
| 276 | +if __name__ == "__main__": |
| 277 | + weight = torch.tensor([ |
| 278 | + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], |
| 279 | + [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], |
| 280 | + ]) |
| 281 | + |
| 282 | + num_replicas = 16 |
| 283 | + num_groups = 4 |
| 284 | + num_nodes = 2 |
| 285 | + num_gpus = 8 |
| 286 | + |
| 287 | + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, |
| 288 | + num_groups, num_nodes, |
| 289 | + num_gpus) |
| 290 | + print(phy2log) |
| 291 | + |
| 292 | + test_basic_rebalance() |
0 commit comments