[Benchmark] Add expert parallel support to MoE benchmark #20876

Open · wants to merge 3 commits into main

Changes from all commits
benchmarks/kernels/benchmark_moe.py: 201 changes (179 additions, 22 deletions)
@@ -3,6 +3,7 @@

import argparse
import json
import os
import time
from contextlib import nullcontext
from datetime import datetime
@@ -11,17 +12,43 @@

import ray
import torch
import torch.distributed as dist
from ray.experimental.tqdm_ray import tqdm

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
- from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, get_open_port

FP8_DTYPE = current_platform.fp8_dtype()


def build_expert_map(
global_num_experts: int, ep_size: int, ep_rank: int
) -> tuple[int, torch.Tensor]:
"""Build expert map for expert parallel. Returns (local_num_experts, expert_map)."""
# Calculate base number of experts per rank
base_experts = global_num_experts // ep_size

# Create expert map
expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)

if ep_rank < (ep_size - 1):
# Non-last ranks get base number of experts
local_num_experts = base_experts
start = ep_rank * base_experts
end = start + local_num_experts
expert_map[start:end] = torch.arange(local_num_experts, dtype=torch.int32)
else:
# Last rank gets all remaining experts
start = ep_rank * base_experts
local_num_experts = global_num_experts - start
expert_map[start:] = torch.arange(local_num_experts, dtype=torch.int32)

return local_num_experts, expert_map.cuda()
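
A worked example of the mapping above (hypothetical sizes; the arithmetic mirrors build_expert_map, kept on CPU so it runs without a GPU):

import torch

def expert_map_demo(global_num_experts: int, ep_size: int, ep_rank: int):
    # Same partitioning as build_expert_map: equal shares, remainder to the last rank.
    base_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    start = ep_rank * base_experts
    if ep_rank < ep_size - 1:
        local_num_experts = base_experts
    else:
        local_num_experts = global_num_experts - start
    expert_map[start : start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32
    )
    return local_num_experts, expert_map

# 10 experts over 3 ranks: ranks 0 and 1 own 3 experts each, rank 2 owns 4.
for rank in range(3):
    print(rank, expert_map_demo(10, 3, rank))
# 0 (3, tensor([ 0,  1,  2, -1, -1, -1, -1, -1, -1, -1], dtype=torch.int32))
# 1 (3, tensor([-1, -1, -1,  0,  1,  2, -1, -1, -1, -1], dtype=torch.int32))
# 2 (4, tensor([-1, -1, -1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32))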


class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
@@ -44,15 +71,29 @@ def benchmark_config(
num_iters: int = 100,
block_quant_shape: list[int] = None,
use_deep_gemm: bool = False,
enable_expert_parallel: bool = False,
ep_size: int = 1,
ep_rank: int = 0,
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)

# For expert parallel, calculate local and global expert counts
global_num_experts = num_experts
if enable_expert_parallel:
local_num_experts, expert_map = build_expert_map(
global_num_experts, ep_size, ep_rank
)
else:
local_num_experts = num_experts
expert_map = None

if use_int8_w8a16:
w1 = torch.randint(
-127,
127,
(
- num_experts,
local_num_experts,
shard_intermediate_size,
hidden_size,
),
@@ -62,37 +103,43 @@
-127,
127,
(
- num_experts,
local_num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8,
)
else:
w1 = torch.randn(
- num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
local_num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
- num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
local_num_experts,
hidden_size,
shard_intermediate_size // 2,
dtype=init_dtype,
)
- gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
# Gating output uses global number of experts
gating_output = torch.randn(
num_iters, num_tokens, global_num_experts, dtype=torch.float32
)

w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if use_int8_w8a16:
w1_scale = torch.randn(
- (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
(local_num_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
- w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
w2_scale = torch.randn((hidden_size, local_num_experts), dtype=torch.float32)
if use_deep_gemm:
# we use the default block shape for deepgemm
block_quant_shape = [128, 128]
if use_fp8_w8a8:
if block_quant_shape:
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
- E = num_experts
E = local_num_experts
N = shard_intermediate_size // 2
K = hidden_size
factor_for_scale = 1e-2
@@ -109,16 +156,16 @@
* factor_for_scale
)
else:
- w1_scale = torch.randn(num_experts, dtype=torch.float32)
- w2_scale = torch.randn(num_experts, dtype=torch.float32)
w1_scale = torch.randn(local_num_experts, dtype=torch.float32)
w2_scale = torch.randn(local_num_experts, dtype=torch.float32)

a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)

w1 = w1.to(FP8_DTYPE)
w2 = w2.to(FP8_DTYPE)

- input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
input_gating = torch.empty(num_tokens, global_num_experts, dtype=torch.float32)

def prepare(i: int):
input_gating.copy_(gating_output[i])
@@ -145,6 +192,8 @@ def run():
a2_scale=a2_scale,
block_shape=block_quant_shape,
allow_deep_gemm=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
else:
fused_moe(
@@ -162,6 +211,8 @@ def run():
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
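
In miniature, what the two new kwargs do: topk_ids from the router index global experts, and expert_map translates them into this rank's local slots, with -1 marking experts owned by other ranks (hypothetical values; fused_moe is expected to skip the -1 entries on this rank):

import torch

# Rank 1 of 2 with 8 global experts owns globals 4..7 as locals 0..3.
expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3], dtype=torch.int32)
topk_ids = torch.tensor([[2, 5], [7, 0]])  # global expert ids from the router
print(expert_map[topk_ids])
# tensor([[-1,  1],
#         [ 3, -1]], dtype=torch.int32)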

# JIT compilation & warmup
@@ -383,14 +434,56 @@ def merge_unique_dicts(list1, list2):

@ray.remote(num_gpus=1)
class BenchmarkWorker:
- def __init__(self, seed: int) -> None:
def __init__(
self,
seed: int,
enable_expert_parallel: bool,
worker_id: int,
total_workers: int,
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(seed)
self.seed = seed
self.enable_expert_parallel = enable_expert_parallel
self.worker_id = worker_id
self.total_workers = total_workers
self.ep_size = total_workers if enable_expert_parallel else 1
self.ep_rank = worker_id if enable_expert_parallel else 0
# Get the device ID to allocate tensors and kernels
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
- self.device_id = int(ray.get_gpu_ids()[0])
gpu_ids = ray.get_gpu_ids()
if gpu_ids:
self.device_id = int(gpu_ids[0])
else:
self.device_id = 0
self.distributed_initialized = False

def init_distributed(self, master_addr: str, master_port: int) -> None:
"""Initialize torch.distributed for expert parallel."""
if self.distributed_initialized:
return

os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(self.worker_id)
os.environ["WORLD_SIZE"] = str(self.total_workers)

if not dist.is_initialized():
dist.init_process_group(
backend="nccl", world_size=self.total_workers, rank=self.worker_id
)

self.distributed_initialized = True

# Set device using local device ID
# Ray workers see their assigned GPU as device 0
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
torch.cuda.set_device(0)

def get_node_ip(self) -> str:
"""Get the IP address of this worker node."""
return ray.util.get_node_ip_address()
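
The torch.cuda.set_device(0) call above leans on Ray's GPU isolation: with num_gpus=1, Ray narrows CUDA_VISIBLE_DEVICES for each actor, so the assigned GPU always shows up as local device 0. A small probe actor to confirm this (hypothetical, illustration only):

import os

import ray
import torch

@ray.remote(num_gpus=1)
class DeviceProbe:
    def probe(self) -> tuple[str, int]:
        # Ray exports CUDA_VISIBLE_DEVICES per actor, so torch sees one GPU.
        return os.environ.get("CUDA_VISIBLE_DEVICES", ""), torch.cuda.device_count()

# e.g. ray.get(DeviceProbe.remote().probe.remote()) might return ("3", 1):
# physical GPU 3 is this actor's local device 0.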

def benchmark(
self,
@@ -439,6 +532,9 @@ def benchmark(
num_iters=100,
block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm,
enable_expert_parallel=self.enable_expert_parallel,
ep_size=self.ep_size,
ep_rank=self.ep_rank,
)
return config, kernel_time

@@ -491,6 +587,9 @@ def tune(
num_iters=20,
block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm,
enable_expert_parallel=self.enable_expert_parallel,
ep_size=self.ep_size,
ep_rank=self.ep_rank,
)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
@@ -535,18 +634,32 @@ def save_configs(
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
block_quant_shape: list[int],
enable_expert_parallel: bool = False,
ep_size: int = 1,
) -> None:
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)

# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
- filename = get_config_file_name(
-     num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
- )

# vLLM uses local expert count in filename when EP is enabled
if enable_expert_parallel:
local_num_experts = num_experts // ep_size
filename = get_config_file_name(
local_num_experts,
shard_intermediate_size // 2,
dtype_str,
block_quant_shape,
)
else:
filename = get_config_file_name(
num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
)
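
Concretely: tuning a 64-expert model on 4 GPUs with EP saves the config under E=16 rather than E=64, matching the key vLLM looks up at runtime when each rank loads only its local experts. Illustrative filenames (hypothetical sizes and device; the exact fields come from get_config_file_name):

# E=64,N=1024,device_name=NVIDIA_H100.json   <- TP-only tuning, global expert count
# E=16,N=1024,device_name=NVIDIA_H100.json   <- EP on 4 GPUs, 64 // 4 local experts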

print(f"Writing best config to {filename}...")

with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
@@ -570,29 +683,30 @@ def main(args: argparse.Namespace):
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
- shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
- shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
- shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
- shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Support for llama4
config = config.get_text_config()
# Default: Mixtral.
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size

# Calculate shard_intermediate_size based on EP mode
if args.enable_expert_parallel:
shard_intermediate_size = 2 * intermediate_size
else:
shard_intermediate_size = 2 * intermediate_size // args.tp_size
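
The distinction matters: TP slices every expert's fused gate/up projection across ranks, while EP keeps whole experts per rank, so the fused intermediate width stays unsliced. With hypothetical Mixtral-like numbers:

intermediate_size = 14336  # hypothetical, e.g. Mixtral-8x7B
tp_size = 4

# TP: each rank holds a 1/tp_size slice of the fused gate|up weight.
print(2 * intermediate_size // tp_size)  # 7168

# EP: each rank holds complete experts, so the full fused width remains.
print(2 * intermediate_size)  # 28672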

hidden_size = config.hidden_size
@@ -639,7 +753,47 @@

ray.init()
num_gpus = int(ray.available_resources()["GPU"])
- workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

if args.enable_expert_parallel:
if args.tp_size != num_gpus:
raise ValueError(
"When running with --enable-expert-parallel, the specified "
"--tp-size must be equal to the number of available GPUs. "
f"Got --tp-size={args.tp_size} and {num_gpus} GPUs.\n"
"To tune for a specific number of GPUs for expert parallel, "
"please restrict the visible devices using the CUDA_VISIBLE_DEVICES"
)
if args.tp_size < 2:
raise ValueError(
f"Expert parallel benchmark requires at least 2 GPUs, "
f"but got --tp-size={args.tp_size}."
)

workers = [
BenchmarkWorker.remote(
args.seed, args.enable_expert_parallel, worker_id=i, total_workers=num_gpus
)
for i in range(num_gpus)
]

# Initialize distributed communication for expert parallel
if args.enable_expert_parallel:
# Get worker IPs to determine master
worker_ips = ray.get([w.get_node_ip.remote() for w in workers])

# Use first worker's IP as master
master_addr = worker_ips[0]
master_port = get_open_port()

# Initialize distributed on all workers
init_futures = [
w.init_distributed.remote(master_addr, master_port) for w in workers
]
ray.get(init_futures)
print(
f"Initialized distributed environment with master at "
f"{master_addr}:{master_port}"
)

def _distribute(method: str, inputs: list[Any]) -> list[Any]:
outputs = []
@@ -690,6 +844,8 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
use_fp8_w8a8,
use_int8_w8a16,
block_quant_shape,
args.enable_expert_parallel,
args.tp_size,
)
end = time.time()
print(f"Tuning took {end - start:.2f} seconds")
@@ -735,6 +891,7 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
parser.add_argument("--tune", action="store_true")
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--model-prefix", type=str, required=False)
parser.add_argument("--enable-expert-parallel", action="store_true")
args = parser.parse_args()

main(args)
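
For reference, a hypothetical tuning invocation on an 8-GPU node (every flag except --enable-expert-parallel is the script's existing interface; set CUDA_VISIBLE_DEVICES to tune a smaller EP world):

python benchmarks/kernels/benchmark_moe.py \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    --tp-size 8 \
    --enable-expert-parallel \
    --tune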