diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 51c9f68e43a..31a7f3e2fea 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -3,6 +3,7 @@
 
 import argparse
 import json
+import os
 import time
 from contextlib import nullcontext
 from datetime import datetime
@@ -11,17 +12,43 @@
 
 import ray
 import torch
+import torch.distributed as dist
 from ray.experimental.tqdm_ray import tqdm
 
 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import get_config
 from vllm.triton_utils import triton
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, get_open_port
 
 FP8_DTYPE = current_platform.fp8_dtype()
 
 
+def build_expert_map(
+    global_num_experts: int, ep_size: int, ep_rank: int
+) -> tuple[int, torch.Tensor]:
+    """Build expert map for expert parallel. Returns (local_num_experts, expert_map)."""
+    # Calculate base number of experts per rank
+    base_experts = global_num_experts // ep_size
+
+    # Create expert map
+    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
+
+    if ep_rank < (ep_size - 1):
+        # Non-last ranks get base number of experts
+        local_num_experts = base_experts
+        start = ep_rank * base_experts
+        end = start + local_num_experts
+        expert_map[start:end] = torch.arange(local_num_experts, dtype=torch.int32)
+    else:
+        # Last rank gets all remaining experts
+        start = ep_rank * base_experts
+        local_num_experts = global_num_experts - start
+        expert_map[start:] = torch.arange(local_num_experts, dtype=torch.int32)
+
+    return local_num_experts, expert_map.cuda()
+
+
 class BenchmarkConfig(TypedDict):
     BLOCK_SIZE_M: int
     BLOCK_SIZE_N: int
@@ -44,15 +71,29 @@ def benchmark_config(
     num_iters: int = 100,
     block_quant_shape: list[int] = None,
     use_deep_gemm: bool = False,
+    enable_expert_parallel: bool = False,
+    ep_size: int = 1,
+    ep_rank: int = 0,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
+
+    # For expert parallel, calculate local and global expert counts
+    global_num_experts = num_experts
+    if enable_expert_parallel:
+        local_num_experts, expert_map = build_expert_map(
+            global_num_experts, ep_size, ep_rank
+        )
+    else:
+        local_num_experts = num_experts
+        expert_map = None
+
     if use_int8_w8a16:
         w1 = torch.randint(
             -127,
             127,
             (
-                num_experts,
+                local_num_experts,
                 shard_intermediate_size,
                 hidden_size,
             ),
@@ -62,7 +103,7 @@
             -127,
             127,
             (
-                num_experts,
+                local_num_experts,
                 hidden_size,
                 shard_intermediate_size // 2,
             ),
@@ -70,12 +111,18 @@
         )
     else:
         w1 = torch.randn(
-            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
+            local_num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
         )
         w2 = torch.randn(
-            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
+            local_num_experts,
+            hidden_size,
+            shard_intermediate_size // 2,
+            dtype=init_dtype,
         )
-    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
+    # Gating output uses global number of experts
+    gating_output = torch.randn(
+        num_iters, num_tokens, global_num_experts, dtype=torch.float32
+    )
 
     w1_scale = None
     w2_scale = None
@@ -83,16 +130,16 @@
     a2_scale = None
     if use_int8_w8a16:
         w1_scale = torch.randn(
-            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
+            (local_num_experts, 2 * shard_intermediate_size), dtype=torch.float32
         )
-        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
+        w2_scale = torch.randn((hidden_size, local_num_experts), dtype=torch.float32)
     if use_deep_gemm:
         # we use the default block shape for deepgemm
         block_quant_shape = [128, 128]
     if use_fp8_w8a8:
         if block_quant_shape:
             block_n, block_k = block_quant_shape[0], block_quant_shape[1]
-            E = num_experts
+            E = local_num_experts
             N = shard_intermediate_size // 2
             K = hidden_size
             factor_for_scale = 1e-2
@@ -109,8 +156,8 @@
                 * factor_for_scale
             )
         else:
-            w1_scale = torch.randn(num_experts, dtype=torch.float32)
-            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+            w1_scale = torch.randn(local_num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(local_num_experts, dtype=torch.float32)
         a1_scale = torch.randn(1, dtype=torch.float32)
         a2_scale = torch.randn(1, dtype=torch.float32)
 
@@ -118,7 +165,7 @@
         w1 = w1.to(FP8_DTYPE)
         w2 = w2.to(FP8_DTYPE)
 
-    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
+    input_gating = torch.empty(num_tokens, global_num_experts, dtype=torch.float32)
 
     def prepare(i: int):
         input_gating.copy_(gating_output[i])
@@ -145,6 +192,8 @@ def run():
                     a2_scale=a2_scale,
                     block_shape=block_quant_shape,
                     allow_deep_gemm=True,
+                    global_num_experts=global_num_experts,
+                    expert_map=expert_map,
                 )
             else:
                 fused_moe(
@@ -162,6 +211,8 @@ def run():
                     a1_scale=a1_scale,
                     a2_scale=a2_scale,
                     block_shape=block_quant_shape,
+                    global_num_experts=global_num_experts,
+                    expert_map=expert_map,
                 )
 
     # JIT compilation & warmup
@@ -383,14 +434,56 @@ def merge_unique_dicts(list1, list2):
 
 @ray.remote(num_gpus=1)
 class BenchmarkWorker:
-    def __init__(self, seed: int) -> None:
+    def __init__(
+        self,
+        seed: int,
+        enable_expert_parallel: bool,
+        worker_id: int,
+        total_workers: int,
+    ) -> None:
         torch.set_default_device("cuda")
         current_platform.seed_everything(seed)
         self.seed = seed
+        self.enable_expert_parallel = enable_expert_parallel
+        self.worker_id = worker_id
+        self.total_workers = total_workers
+        self.ep_size = total_workers if enable_expert_parallel else 1
+        self.ep_rank = worker_id if enable_expert_parallel else 0
         # Get the device ID to allocate tensors and kernels
         # on the respective GPU. This is required for Ray to work
         # correctly with multi-GPU tuning on the ROCm platform.
-        self.device_id = int(ray.get_gpu_ids()[0])
+        gpu_ids = ray.get_gpu_ids()
+        if gpu_ids:
+            self.device_id = int(gpu_ids[0])
+        else:
+            self.device_id = 0
+        self.distributed_initialized = False
+
+    def init_distributed(self, master_addr: str, master_port: int) -> None:
+        """Initialize torch.distributed for expert parallel."""
+        if self.distributed_initialized:
+            return
+
+        os.environ["MASTER_ADDR"] = master_addr
+        os.environ["MASTER_PORT"] = str(master_port)
+        os.environ["RANK"] = str(self.worker_id)
+        os.environ["WORLD_SIZE"] = str(self.total_workers)
+
+        if not dist.is_initialized():
+            dist.init_process_group(
+                backend="nccl", world_size=self.total_workers, rank=self.worker_id
+            )
+
+        self.distributed_initialized = True
+
+        # Set device using local device ID
+        # Ray workers see their assigned GPU as device 0
+        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+            torch.cuda.set_device(0)
+
+    def get_node_ip(self) -> str:
+        """Get the IP address of this worker node."""
+        return ray.util.get_node_ip_address()
 
     def benchmark(
         self,
@@ -439,6 +532,9 @@ def benchmark(
             num_iters=100,
             block_quant_shape=block_quant_shape,
             use_deep_gemm=use_deep_gemm,
+            enable_expert_parallel=self.enable_expert_parallel,
+            ep_size=self.ep_size,
+            ep_rank=self.ep_rank,
         )
         return config, kernel_time
 
@@ -491,6 +587,9 @@ def tune(
                        num_iters=20,
                        block_quant_shape=block_quant_shape,
                        use_deep_gemm=use_deep_gemm,
+                        enable_expert_parallel=self.enable_expert_parallel,
+                        ep_size=self.ep_size,
+                        ep_rank=self.ep_rank,
                    )
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
@@ -535,6 +634,8 @@ def save_configs(
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
     block_quant_shape: list[int],
+    enable_expert_parallel: bool = False,
+    ep_size: int = 1,
 ) -> None:
     dtype_str = get_config_dtype_str(
         dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
@@ -542,11 +643,23 @@
 
     # NOTE(woosuk): The current naming convention uses w2.shape[2], which
     # is the intermediate size after silu_and_mul.
-    filename = get_config_file_name(
-        num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
-    )
+
+    # vLLM uses local expert count in filename when EP is enabled
+    if enable_expert_parallel:
+        local_num_experts = num_experts // ep_size
+        filename = get_config_file_name(
+            local_num_experts,
+            shard_intermediate_size // 2,
+            dtype_str,
+            block_quant_shape,
+        )
+    else:
+        filename = get_config_file_name(
+            num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
+        )
 
     print(f"Writing best config to {filename}...")
+
     with open(filename, "w") as f:
         json.dump(configs, f, indent=4)
         f.write("\n")
@@ -570,22 +683,18 @@ def main(args: argparse.Namespace):
         E = config.ffn_config.moe_num_experts
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] == "JambaForCausalLM":
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
         # Support for llama4
         config = config.get_text_config()
@@ -593,6 +702,11 @@ def main(args: argparse.Namespace):
         E = config.num_local_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
+
+    # Calculate shard_intermediate_size based on EP mode
+    if args.enable_expert_parallel:
+        shard_intermediate_size = 2 * intermediate_size
+    else:
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
 
     hidden_size = config.hidden_size
@@ -639,7 +753,47 @@ def main(args: argparse.Namespace):
 
     ray.init()
     num_gpus = int(ray.available_resources()["GPU"])
-    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
+
+    if args.enable_expert_parallel:
+        if args.tp_size != num_gpus:
+            raise ValueError(
+                "When running with --enable-expert-parallel, the specified "
+                "--tp-size must be equal to the number of available GPUs. "
+                f"Got --tp-size={args.tp_size} and {num_gpus} GPUs.\n"
+                "To tune for a specific number of GPUs for expert parallel, "
+                "please restrict the visible devices using CUDA_VISIBLE_DEVICES."
+            )
+        if args.tp_size < 2:
+            raise ValueError(
+                "Expert parallel benchmark requires at least 2 GPUs, "
+                f"but got --tp-size={args.tp_size}."
+            )
+
+    workers = [
+        BenchmarkWorker.remote(
+            args.seed, args.enable_expert_parallel, worker_id=i, total_workers=num_gpus
+        )
+        for i in range(num_gpus)
+    ]
+
+    # Initialize distributed communication for expert parallel
+    if args.enable_expert_parallel:
+        # Get worker IPs to determine master
+        worker_ips = ray.get([w.get_node_ip.remote() for w in workers])
+
+        # Use first worker's IP as master
+        master_addr = worker_ips[0]
+        master_port = get_open_port()
+
+        # Initialize distributed on all workers
+        init_futures = [
+            w.init_distributed.remote(master_addr, master_port) for w in workers
+        ]
+        ray.get(init_futures)
+        print(
+            "Initialized distributed environment with master at "
+            f"{master_addr}:{master_port}"
+        )
 
     def _distribute(method: str, inputs: list[Any]) -> list[Any]:
         outputs = []
@@ -690,6 +844,8 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
             use_fp8_w8a8,
             use_int8_w8a16,
             block_quant_shape,
+            args.enable_expert_parallel,
+            args.tp_size,
         )
         end = time.time()
         print(f"Tuning took {end - start:.2f} seconds")
@@ -735,6 +891,7 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
     parser.add_argument("--tune", action="store_true")
     parser.add_argument("--trust-remote-code", action="store_true")
     parser.add_argument("--model-prefix", type=str, required=False)
+    parser.add_argument("--enable-expert-parallel", action="store_true")
     args = parser.parse_args()
 
     main(args)
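
For reference, below is a minimal standalone sketch of the expert-partitioning scheme that the patch's build_expert_map implements: the first ep_size - 1 ranks each get global_num_experts // ep_size experts, and the last rank absorbs the remainder, with -1 marking experts not owned by the rank. This sketch is not part of the patch; the helper name expert_map_for_rank and the 10-expert / 4-rank numbers are illustrative, and the .cuda() call is omitted so it runs on CPU.

# Standalone sketch (assumption: same partitioning as build_expert_map above, CPU-only).
import torch


def expert_map_for_rank(global_num_experts: int, ep_size: int, ep_rank: int):
    """Return (local_num_experts, expert_map) for one EP rank, mirroring the diff."""
    base_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    start = ep_rank * base_experts
    if ep_rank < ep_size - 1:
        local_num_experts = base_experts
        expert_map[start : start + local_num_experts] = torch.arange(
            local_num_experts, dtype=torch.int32
        )
    else:
        # Last rank takes all remaining experts when the division is uneven.
        local_num_experts = global_num_experts - start
        expert_map[start:] = torch.arange(local_num_experts, dtype=torch.int32)
    return local_num_experts, expert_map


if __name__ == "__main__":
    # Example: 10 global experts over 4 EP ranks -> 2, 2, 2, 4 local experts.
    for rank in range(4):
        local, emap = expert_map_for_rank(10, 4, rank)
        print(f"rank {rank}: local_num_experts={local}, expert_map={emap.tolist()}")

With the patch applied, each Ray worker benchmarks only its local expert shard (weights sized by local_num_experts) while the gating tensor and global_num_experts/expert_map arguments keep routing defined over the full expert set, which is what the new --enable-expert-parallel flag exercises.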