Commit 4b80612

Fix EP simulation: use expert map
Signed-off-by: Alan Chen <zc2610@nyu.edu>
1 parent: b502a1f

File tree: 1 file changed

benchmarks/kernels/benchmark_moe.py

Lines changed: 110 additions & 19 deletions
@@ -12,17 +12,43 @@
 
 import ray
 import torch
+import torch.distributed as dist
 from ray.experimental.tqdm_ray import tqdm
 
 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import get_config
 from vllm.triton_utils import triton
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, get_open_port
 
 FP8_DTYPE = current_platform.fp8_dtype()
 
 
+def build_expert_map(global_num_experts: int,
+                     ep_size: int,
+                     ep_rank: int) -> tuple[int, torch.Tensor]:
+    """Build expert map for expert parallel. Returns (local_num_experts, expert_map)."""
+    # Calculate base number of experts per rank
+    base_experts = global_num_experts // ep_size
+
+    # Create expert map
+    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
+
+    if ep_rank < (ep_size - 1):
+        # Non-last ranks get base number of experts
+        local_num_experts = base_experts
+        start = ep_rank * base_experts
+        end = start + local_num_experts
+        expert_map[start:end] = torch.arange(local_num_experts, dtype=torch.int32)
+    else:
+        # Last rank gets all remaining experts
+        start = ep_rank * base_experts
+        local_num_experts = global_num_experts - start
+        expert_map[start:] = torch.arange(local_num_experts, dtype=torch.int32)
+
+    return local_num_experts, expert_map.cuda()
+
+
 class BenchmarkConfig(TypedDict):
     BLOCK_SIZE_M: int
     BLOCK_SIZE_N: int
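As an illustration of what this helper produces (a sketch, not part of the commit; it restates the mapping on the CPU because the real function moves the map to the GPU, and the sizes are made up):

import torch

def expert_map_cpu(global_num_experts: int, ep_size: int, ep_rank: int):
    # Same mapping logic as build_expert_map above, minus the .cuda() call.
    base = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    if ep_rank < ep_size - 1:
        local = base
        start = ep_rank * base
        expert_map[start:start + local] = torch.arange(local, dtype=torch.int32)
    else:
        start = ep_rank * base
        local = global_num_experts - start
        expert_map[start:] = torch.arange(local, dtype=torch.int32)
    return local, expert_map

# Rank 1 of 4 owns global experts 2-3 and sees them as local ids 0-1:
#   (2, tensor([-1, -1,  0,  1, -1, -1, -1, -1, -1, -1], dtype=torch.int32))
print(expert_map_cpu(10, 4, 1))
# The last rank absorbs the remainder and owns global experts 6-9:
#   (4, tensor([-1, -1, -1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32))
print(expert_map_cpu(10, 4, 3))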
@@ -47,20 +73,25 @@ def benchmark_config(
     use_deep_gemm: bool = False,
     enable_expert_parallel: bool = False,
     ep_size: int = 1,
+    ep_rank: int = 0,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
 
-    # For expert parallel, only create weights for local experts
+    # For expert parallel, calculate local and global expert counts
+    global_num_experts = num_experts
     if enable_expert_parallel:
-        num_experts = num_experts // ep_size
+        local_num_experts, expert_map = build_expert_map(global_num_experts, ep_size, ep_rank)
+    else:
+        local_num_experts = num_experts
+        expert_map = None
 
     if use_int8_w8a16:
         w1 = torch.randint(
             -127,
             127,
             (
-                num_experts,
+                local_num_experts,
                 shard_intermediate_size,
                 hidden_size,
             ),
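A quick sanity check on the sizes this change implies (again a sketch with made-up dimensions, not part of the commit): each rank now allocates weight tensors only for its local experts, while the gating tensor still spans every global expert.

# Hypothetical sizes chosen for illustration only.
global_num_experts, ep_size = 64, 4
local_num_experts = global_num_experts // ep_size      # 16 on every non-last rank
num_tokens, hidden_size, shard_intermediate_size = 1024, 4096, 14336

w1_shape = (local_num_experts, shard_intermediate_size, hidden_size)
gating_shape = (num_tokens, global_num_experts)
print(w1_shape, gating_shape)   # (16, 14336, 4096) (1024, 64)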
@@ -70,37 +101,38 @@ def benchmark_config(
             -127,
             127,
             (
-                num_experts,
+                local_num_experts,
                 hidden_size,
                 shard_intermediate_size // 2,
             ),
             dtype=torch.int8,
         )
     else:
         w1 = torch.randn(
-            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
+            local_num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
         )
         w2 = torch.randn(
-            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
+            local_num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
         )
-    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
+    # Gating output uses global number of experts
+    gating_output = torch.randn(num_iters, num_tokens, global_num_experts, dtype=torch.float32)
 
     w1_scale = None
     w2_scale = None
     a1_scale = None
     a2_scale = None
     if use_int8_w8a16:
         w1_scale = torch.randn(
-            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
+            (local_num_experts, 2 * shard_intermediate_size), dtype=torch.float32
         )
-        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
+        w2_scale = torch.randn((hidden_size, local_num_experts), dtype=torch.float32)
     if use_deep_gemm:
         # we use the default block shape for deepgemm
         block_quant_shape = [128, 128]
     if use_fp8_w8a8:
         if block_quant_shape:
             block_n, block_k = block_quant_shape[0], block_quant_shape[1]
-            E = num_experts
+            E = local_num_experts
             N = shard_intermediate_size // 2
             K = hidden_size
             factor_for_scale = 1e-2
@@ -117,16 +149,16 @@ def benchmark_config(
                 * factor_for_scale
             )
         else:
-            w1_scale = torch.randn(num_experts, dtype=torch.float32)
-            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+            w1_scale = torch.randn(local_num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(local_num_experts, dtype=torch.float32)
 
         a1_scale = torch.randn(1, dtype=torch.float32)
         a2_scale = torch.randn(1, dtype=torch.float32)
 
         w1 = w1.to(FP8_DTYPE)
         w2 = w2.to(FP8_DTYPE)
 
-    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
+    input_gating = torch.empty(num_tokens, global_num_experts, dtype=torch.float32)
 
     def prepare(i: int):
         input_gating.copy_(gating_output[i])
@@ -153,6 +185,8 @@ def run():
                 a2_scale=a2_scale,
                 block_shape=block_quant_shape,
                 allow_deep_gemm=True,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
             )
         else:
             fused_moe(
@@ -170,6 +204,8 @@ def run():
                 a1_scale=a1_scale,
                 a2_scale=a2_scale,
                 block_shape=block_quant_shape,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
             )
 
     # JIT compilation & warmup
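Conceptually, the expert_map passed to these calls lets the kernel translate the router's global top-k expert ids into local ids and skip tokens routed to experts owned by other ranks. The snippet below is only a hedged sketch of that idea, not vLLM's actual fused_moe internals:

import torch

global_num_experts = 8
expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
expert_map[:4] = torch.arange(4, dtype=torch.int32)   # pretend this rank owns experts 0-3

topk_ids = torch.tensor([[1, 5], [6, 2]])             # global expert ids chosen by the router
local_ids = expert_map[topk_ids]                      # tensor([[ 1, -1], [-1,  2]], dtype=torch.int32)
on_this_rank = local_ids >= 0                         # mask of work this rank actually performs
print(local_ids, on_this_rank)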
@@ -410,7 +446,41 @@ def __init__(
         # Get the device ID to allocate tensors and kernels
         # on the respective GPU. This is required for Ray to work
         # correctly with multi-GPU tuning on the ROCm platform.
-        self.device_id = int(ray.get_gpu_ids()[0])
+        gpu_ids = ray.get_gpu_ids()
+        if gpu_ids:
+            self.device_id = int(gpu_ids[0])
+        else:
+            self.device_id = 0
+        self.distributed_initialized = False
+
+    def init_distributed(self, master_addr: str, master_port: int) -> None:
+        """Initialize torch.distributed for expert parallel."""
+        if self.distributed_initialized:
+            return
+
+        os.environ['MASTER_ADDR'] = master_addr
+        os.environ['MASTER_PORT'] = str(master_port)
+        os.environ['RANK'] = str(self.worker_id)
+        os.environ['WORLD_SIZE'] = str(self.total_workers)
+
+        if not dist.is_initialized():
+            dist.init_process_group(
+                backend='nccl',
+                world_size=self.total_workers,
+                rank=self.worker_id
+            )
+
+        self.distributed_initialized = True
+
+        # Set device using local device ID
+        # Ray workers see their assigned GPU as device 0
+        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+            torch.cuda.set_device(0)
+
+    def get_node_ip(self) -> str:
+        """Get the IP address of this worker node."""
+        import socket
+        return socket.gethostbyname(socket.gethostname())
 
     def benchmark(
         self,
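The torch.cuda.set_device(0) above relies on Ray narrowing CUDA_VISIBLE_DEVICES for each GPU actor, so the assigned GPU appears locally as device 0 even though ray.get_gpu_ids() reports its global id. A small probe illustrating this behavior (an assumption-laden sketch, not part of the commit; it needs Ray and at least one GPU):

import os
import ray
import torch

@ray.remote(num_gpus=1)
class Probe:
    def report(self):
        # Ray reports the globally assigned GPU id, but only that GPU is visible here.
        return (ray.get_gpu_ids(),
                os.environ.get("CUDA_VISIBLE_DEVICES"),
                torch.cuda.device_count())

ray.init()
probe = Probe.remote()
print(ray.get(probe.report.remote()))   # e.g. ([3], '3', 1) on a multi-GPU node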
@@ -461,6 +531,7 @@ def benchmark(
             use_deep_gemm=use_deep_gemm,
             enable_expert_parallel=self.enable_expert_parallel,
             ep_size=self.ep_size,
+            ep_rank=self.ep_rank,
         )
         return config, kernel_time
@@ -515,6 +586,7 @@ def tune(
                     use_deep_gemm=use_deep_gemm,
                     enable_expert_parallel=self.enable_expert_parallel,
                     ep_size=self.ep_size,
+                    ep_rank=self.ep_rank,
                 )
             except triton.runtime.autotuner.OutOfResources:
                 # Some configurations may be invalid and fail to compile.
@@ -581,6 +653,7 @@ def save_configs(
     )
 
     print(f"Writing best config to {filename}...")
+
     with open(filename, "w") as f:
         json.dump(configs, f, indent=4)
         f.write("\n")
@@ -692,11 +765,12 @@ def main(args: argparse.Namespace):
     if args.enable_expert_parallel:
         if args.tp_size != num_gpus:
             raise ValueError(
-                "When running with --enable-expert-parallel, the specified "
-                "--tp-size must be equal to the number of available GPUs. "
+                "When running with --enable-expert-parallel, the EP size is "
+                "automatically set to the number of available GPUs, and "
+                "--tp-size must match this value for correct benchmarking. "
                 f"Got --tp-size={args.tp_size} and {num_gpus} GPUs.\n"
-                "To tune for a specific number of GPUs for expert parallel, "
-                "please restrict the visible devices using the CUDA_VISIBLE_DEVICES"
+                "To benchmark with a specific EP size, please restrict the visible "
+                "devices using the CUDA_VISIBLE_DEVICES environment variable."
             )
         if args.tp_size < 2:
             raise ValueError("Expert parallel requires tensor parallel size >= 2")
@@ -708,6 +782,23 @@ def main(args: argparse.Namespace):
         total_workers=num_gpus
     ) for i in range(num_gpus)]
 
+    # Initialize distributed communication for expert parallel
+    if args.enable_expert_parallel:
+        # Get worker IPs to determine master
+        worker_ips = ray.get([w.get_node_ip.remote() for w in workers])
+
+        # Use first worker's IP as master
+        master_addr = worker_ips[0]
+        master_port = get_open_port()
+
+        # Initialize distributed on all workers
+        init_futures = [
+            w.init_distributed.remote(master_addr, master_port)
+            for w in workers
+        ]
+        ray.get(init_futures)
+        print(f"Initialized distributed environment with master at {master_addr}:{master_port}")
+
     def _distribute(method: str, inputs: list[Any]) -> list[Any]:
         outputs = []
         worker_idx = 0
