[Benchmark] Add expert parallel support to MoE benchmark #20876

Open · wants to merge 3 commits into base: main

Changes from 1 commit
162 changes: 147 additions & 15 deletions benchmarks/kernels/benchmark_moe.py
@@ -3,6 +3,7 @@

import argparse
import json
import os
import time
from contextlib import nullcontext
from datetime import datetime
@@ -11,17 +12,43 @@

import ray
import torch
import torch.distributed as dist
from ray.experimental.tqdm_ray import tqdm

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, get_open_port

FP8_DTYPE = current_platform.fp8_dtype()


def build_expert_map(global_num_experts: int,
ep_size: int,
ep_rank: int) -> tuple[int, torch.Tensor]:
"""Build expert map for expert parallel. Returns (local_num_experts, expert_map)."""
# Calculate base number of experts per rank
base_experts = global_num_experts // ep_size

# Create expert map
expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)

if ep_rank < (ep_size - 1):
# Non-last ranks get base number of experts
local_num_experts = base_experts
start = ep_rank * base_experts
end = start + local_num_experts
expert_map[start:end] = torch.arange(local_num_experts, dtype=torch.int32)
else:
# Last rank gets all remaining experts
start = ep_rank * base_experts
local_num_experts = global_num_experts - start
expert_map[start:] = torch.arange(local_num_experts, dtype=torch.int32)

return local_num_experts, expert_map.cuda()
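
Note, not part of the diff: with the split above, each of the first ep_size - 1 ranks owns global_num_experts // ep_size experts and the last rank absorbs the remainder; expert_map[g] holds the local slot of global expert g on the current rank, or -1 if another rank owns it. A minimal CPU-only sketch of the same logic (the helper name sketch_expert_map is illustrative, and the .cuda() call is dropped so it runs anywhere):

import torch

def sketch_expert_map(global_num_experts: int, ep_size: int, ep_rank: int):
    # Same split as build_expert_map above, kept on CPU for illustration.
    base = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    start = ep_rank * base
    end = start + base if ep_rank < ep_size - 1 else global_num_experts
    expert_map[start:end] = torch.arange(end - start, dtype=torch.int32)
    return end - start, expert_map

for rank in range(3):
    print(rank, sketch_expert_map(10, 3, rank))
# rank 0 owns global experts 0-2: map = [ 0,  1,  2, -1, -1, -1, -1, -1, -1, -1]
# rank 1 owns global experts 3-5: map = [-1, -1, -1,  0,  1,  2, -1, -1, -1, -1]
# rank 2 owns global experts 6-9: map = [-1, -1, -1, -1, -1, -1,  0,  1,  2,  3]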


class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
@@ -44,15 +71,27 @@
num_iters: int = 100,
block_quant_shape: list[int] = None,
use_deep_gemm: bool = False,
enable_expert_parallel: bool = False,
ep_size: int = 1,
ep_rank: int = 0,
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)

# For expert parallel, calculate local and global expert counts
global_num_experts = num_experts
if enable_expert_parallel:
local_num_experts, expert_map = build_expert_map(global_num_experts, ep_size, ep_rank)

Check failure (GitHub Actions / pre-commit), Ruff E501: benchmarks/kernels/benchmark_moe.py:84:89: Line too long (94 > 88)
else:
local_num_experts = num_experts
expert_map = None

if use_int8_w8a16:
w1 = torch.randint(
-127,
127,
(
num_experts,
local_num_experts,
shard_intermediate_size,
hidden_size,
),
@@ -62,37 +101,38 @@
-127,
127,
(
num_experts,
local_num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8,
)
else:
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
local_num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
local_num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype

Check failure (GitHub Actions / pre-commit), Ruff E501: benchmarks/kernels/benchmark_moe.py:115:89: Line too long (90 > 88)
)
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
# Gating output uses global number of experts
gating_output = torch.randn(num_iters, num_tokens, global_num_experts, dtype=torch.float32)

Check failure (GitHub Actions / pre-commit), Ruff E501: benchmarks/kernels/benchmark_moe.py:118:89: Line too long (95 > 88)

w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if use_int8_w8a16:
w1_scale = torch.randn(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
(local_num_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
w2_scale = torch.randn((hidden_size, local_num_experts), dtype=torch.float32)
if use_deep_gemm:
# we use the default block shape for deepgemm
block_quant_shape = [128, 128]
if use_fp8_w8a8:
if block_quant_shape:
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
E = num_experts
E = local_num_experts
N = shard_intermediate_size // 2
K = hidden_size
factor_for_scale = 1e-2
@@ -109,16 +149,16 @@
* factor_for_scale
)
else:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
w1_scale = torch.randn(local_num_experts, dtype=torch.float32)
w2_scale = torch.randn(local_num_experts, dtype=torch.float32)

a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)

w1 = w1.to(FP8_DTYPE)
w2 = w2.to(FP8_DTYPE)

input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
input_gating = torch.empty(num_tokens, global_num_experts, dtype=torch.float32)

def prepare(i: int):
input_gating.copy_(gating_output[i])
@@ -145,6 +185,8 @@
a2_scale=a2_scale,
block_shape=block_quant_shape,
allow_deep_gemm=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
else:
fused_moe(
@@ -162,6 +204,8 @@
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
global_num_experts=global_num_experts,
expert_map=expert_map,
)

# JIT compilation & warmup
@@ -383,14 +427,59 @@

@ray.remote(num_gpus=1)
class BenchmarkWorker:
def __init__(self, seed: int) -> None:
def __init__(
self,
seed: int,
enable_expert_parallel: bool,
worker_id: int,
total_workers: int
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(seed)
self.seed = seed
self.enable_expert_parallel = enable_expert_parallel
self.worker_id = worker_id
self.total_workers = total_workers
self.ep_size = total_workers if enable_expert_parallel else 1
self.ep_rank = worker_id if enable_expert_parallel else 0
# Get the device ID to allocate tensors and kernels
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
self.device_id = int(ray.get_gpu_ids()[0])
gpu_ids = ray.get_gpu_ids()
if gpu_ids:
self.device_id = int(gpu_ids[0])
else:
self.device_id = 0
self.distributed_initialized = False

def init_distributed(self, master_addr: str, master_port: int) -> None:
"""Initialize torch.distributed for expert parallel."""
if self.distributed_initialized:
return

os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = str(master_port)
os.environ['RANK'] = str(self.worker_id)
os.environ['WORLD_SIZE'] = str(self.total_workers)

if not dist.is_initialized():
dist.init_process_group(
backend='nccl',
world_size=self.total_workers,
rank=self.worker_id
)

self.distributed_initialized = True

# Set device using local device ID
# Ray workers see their assigned GPU as device 0
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
torch.cuda.set_device(0)

def get_node_ip(self) -> str:
"""Get the IP address of this worker node."""
import socket
return socket.gethostbyname(socket.gethostname())
Review comment (Contributor, severity: medium):

The current method for obtaining the node's IP address (socket.gethostbyname(socket.gethostname())) can be unreliable in environments with multiple network interfaces, which might lead to issues in setting up the distributed environment. Ray provides a more robust utility, ray.util.get_node_ip_address(), which is specifically designed to correctly identify the node's IP within a Ray cluster. Using this utility would enhance the reliability of IP address resolution.

Suggested change:
-    def get_node_ip(self) -> str:
-        """Get the IP address of this worker node."""
-        import socket
-        return socket.gethostbyname(socket.gethostname())
+    def get_node_ip(self) -> str:
+        """Get the IP address of this worker node."""
+        import ray.util
+        return ray.util.get_node_ip_address()


def benchmark(
self,
@@ -439,6 +528,9 @@
num_iters=100,
block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm,
enable_expert_parallel=self.enable_expert_parallel,
ep_size=self.ep_size,
ep_rank=self.ep_rank,
)
return config, kernel_time

@@ -491,6 +583,9 @@
num_iters=20,
block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm,
enable_expert_parallel=self.enable_expert_parallel,
ep_size=self.ep_size,
ep_rank=self.ep_rank,
)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
@@ -547,6 +642,7 @@
)

print(f"Writing best config to {filename}...")

with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
@@ -639,7 +735,42 @@

ray.init()
num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

if args.enable_expert_parallel:
if args.tp_size != num_gpus:
raise ValueError(
"When running with --enable-expert-parallel, the specified "
"--tp-size must be equal to the number of available GPUs. "
f"Got --tp-size={args.tp_size} and {num_gpus} GPUs.\n"
"To tune for a specific number of GPUs for expert parallel, "
"please restrict the visible devices using the CUDA_VISIBLE_DEVICES"
)
if args.tp_size < 2:
raise ValueError("Expert parallel requires tensor parallel size >= 2")
Review comment (Contributor, severity: medium):

The error message "Expert parallel requires tensor parallel size >= 2" might be confusing. When --enable-expert-parallel is active, --tp-size effectively represents the number of GPUs used for expert parallelism, not tensor parallelism. To improve clarity and user experience, the message should explicitly refer to the GPU requirement.

Suggested change:
-            raise ValueError("Expert parallel requires tensor parallel size >= 2")
+            raise ValueError(f"Expert parallel benchmark requires at least 2 GPUs, but got --tp-size={args.tp_size}.")


workers = [BenchmarkWorker.remote(
args.seed,
args.enable_expert_parallel,
worker_id=i,
total_workers=num_gpus
) for i in range(num_gpus)]

# Initialize distributed communication for expert parallel
if args.enable_expert_parallel:
# Get worker IPs to determine master
worker_ips = ray.get([w.get_node_ip.remote() for w in workers])

# Use first worker's IP as master
master_addr = worker_ips[0]
master_port = get_open_port()

# Initialize distributed on all workers
init_futures = [
w.init_distributed.remote(master_addr, master_port)
for w in workers
]
ray.get(init_futures)
print(f"Initialized distributed environment with master at {master_addr}:{master_port}")

Check failure (GitHub Actions / pre-commit), Ruff E501: benchmarks/kernels/benchmark_moe.py:773:89: Line too long (96 > 88)

def _distribute(method: str, inputs: list[Any]) -> list[Any]:
outputs = []
@@ -735,6 +866,7 @@
parser.add_argument("--tune", action="store_true")
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--model-prefix", type=str, required=False)
parser.add_argument("--enable-expert-parallel", action="store_true")
args = parser.parse_args()

main(args)
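
Note, illustration only: the new global_num_experts / expert_map arguments follow the convention that gating is computed over the global expert set while each rank stores only its local expert weights; expert_map translates global expert ids into local slots, and -1 marks experts owned by other ranks. A rough CPU-only sketch of that lookup, with assumed toy names and shapes (not the fused kernel's actual implementation):

import torch

global_num_experts, top_k = 10, 2
# expert_map for the rank that owns global experts 3-5 (cf. build_expert_map).
expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
expert_map[3:6] = torch.arange(3, dtype=torch.int32)

router_logits = torch.randn(4, global_num_experts)     # gating over all experts
topk_ids = router_logits.topk(top_k, dim=-1).indices   # global expert ids per token
local_ids = expert_map[topk_ids]                        # -1 = expert not on this rank
print(topk_ids)
print(local_ids)  # only entries >= 0 use this rank's w1/w2 shards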