[Benchmark] Add expert parallel support to MoE benchmark #20876
@@ -3,6 +3,7 @@
 import argparse
 import json
+import os
 import time
 from contextlib import nullcontext
 from datetime import datetime
@@ -11,17 +12,43 @@
 import ray
 import torch
+import torch.distributed as dist
 from ray.experimental.tqdm_ray import tqdm

 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import get_config
 from vllm.triton_utils import triton
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, get_open_port

 FP8_DTYPE = current_platform.fp8_dtype()


+def build_expert_map(global_num_experts: int,
+                     ep_size: int,
+                     ep_rank: int) -> tuple[int, torch.Tensor]:
+    """Build expert map for expert parallel. Returns (local_num_experts, expert_map)."""
+    # Calculate base number of experts per rank
+    base_experts = global_num_experts // ep_size
+
+    # Create expert map
+    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
+
+    if ep_rank < (ep_size - 1):
+        # Non-last ranks get base number of experts
+        local_num_experts = base_experts
+        start = ep_rank * base_experts
+        end = start + local_num_experts
+        expert_map[start:end] = torch.arange(local_num_experts, dtype=torch.int32)
+    else:
+        # Last rank gets all remaining experts
+        start = ep_rank * base_experts
+        local_num_experts = global_num_experts - start
+        expert_map[start:] = torch.arange(local_num_experts, dtype=torch.int32)
+
+    return local_num_experts, expert_map.cuda()
+
+
 class BenchmarkConfig(TypedDict):
     BLOCK_SIZE_M: int
     BLOCK_SIZE_N: int
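To make the partitioning concrete, here is a small self-contained sketch that mirrors the mapping logic of build_expert_map above but keeps the tensor on CPU so it runs anywhere (the helper name expert_map_cpu and the 8-expert/3-rank numbers are illustrative only): each rank owns a contiguous slice of experts remapped to local indices starting at 0, and experts owned by other ranks are marked -1.

```python
import torch


def expert_map_cpu(global_num_experts: int, ep_size: int, ep_rank: int):
    # Same partitioning as build_expert_map above, minus the .cuda() move.
    base_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    start = ep_rank * base_experts
    if ep_rank < ep_size - 1:
        local_num_experts = base_experts  # non-last ranks: base share
    else:
        local_num_experts = global_num_experts - start  # last rank: remainder
    expert_map[start:start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32
    )
    return local_num_experts, expert_map


# 8 experts over 3 ranks: ranks 0 and 1 own 2 experts each, rank 2 owns 4.
print(expert_map_cpu(8, 3, 0))  # (2, tensor([ 0,  1, -1, -1, -1, -1, -1, -1], dtype=torch.int32))
print(expert_map_cpu(8, 3, 2))  # (4, tensor([-1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32))
```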
@@ -44,15 +71,27 @@
     num_iters: int = 100,
     block_quant_shape: list[int] = None,
     use_deep_gemm: bool = False,
+    enable_expert_parallel: bool = False,
+    ep_size: int = 1,
+    ep_rank: int = 0,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
+
+    # For expert parallel, calculate local and global expert counts
+    global_num_experts = num_experts
+    if enable_expert_parallel:
+        local_num_experts, expert_map = build_expert_map(global_num_experts, ep_size, ep_rank)
+    else:
+        local_num_experts = num_experts
+        expert_map = None
+
     if use_int8_w8a16:
         w1 = torch.randint(
             -127,
             127,
             (
-                num_experts,
+                local_num_experts,
                 shard_intermediate_size,
                 hidden_size,
             ),
@@ -62,37 +101,38 @@
             -127,
             127,
             (
-                num_experts,
+                local_num_experts,
                 hidden_size,
                 shard_intermediate_size // 2,
             ),
             dtype=torch.int8,
         )
     else:
         w1 = torch.randn(
-            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
+            local_num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
         )
         w2 = torch.randn(
-            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
+            local_num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
         )
-    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
+    # Gating output uses global number of experts
+    gating_output = torch.randn(num_iters, num_tokens, global_num_experts, dtype=torch.float32)

     w1_scale = None
     w2_scale = None
     a1_scale = None
     a2_scale = None
     if use_int8_w8a16:
         w1_scale = torch.randn(
-            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
+            (local_num_experts, 2 * shard_intermediate_size), dtype=torch.float32
         )
-        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
+        w2_scale = torch.randn((hidden_size, local_num_experts), dtype=torch.float32)
     if use_deep_gemm:
         # we use the default block shape for deepgemm
         block_quant_shape = [128, 128]
     if use_fp8_w8a8:
         if block_quant_shape:
             block_n, block_k = block_quant_shape[0], block_quant_shape[1]
-            E = num_experts
+            E = local_num_experts
             N = shard_intermediate_size // 2
             K = hidden_size
             factor_for_scale = 1e-2
@@ -109,16 +149,16 @@
                 * factor_for_scale
             )
         else:
-            w1_scale = torch.randn(num_experts, dtype=torch.float32)
-            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+            w1_scale = torch.randn(local_num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(local_num_experts, dtype=torch.float32)

         a1_scale = torch.randn(1, dtype=torch.float32)
         a2_scale = torch.randn(1, dtype=torch.float32)

         w1 = w1.to(FP8_DTYPE)
         w2 = w2.to(FP8_DTYPE)

-    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
+    input_gating = torch.empty(num_tokens, global_num_experts, dtype=torch.float32)

     def prepare(i: int):
         input_gating.copy_(gating_output[i])
@@ -145,6 +185,8 @@
                     a2_scale=a2_scale,
                     block_shape=block_quant_shape,
                     allow_deep_gemm=True,
+                    global_num_experts=global_num_experts,
+                    expert_map=expert_map,
                 )
             else:
                 fused_moe(
@@ -162,6 +204,8 @@
                     a1_scale=a1_scale,
                     a2_scale=a2_scale,
                     block_shape=block_quant_shape,
+                    global_num_experts=global_num_experts,
+                    expert_map=expert_map,
                 )

     # JIT compilation & warmup
@@ -383,14 +427,59 @@

 @ray.remote(num_gpus=1)
 class BenchmarkWorker:
-    def __init__(self, seed: int) -> None:
+    def __init__(
+        self,
+        seed: int,
+        enable_expert_parallel: bool,
+        worker_id: int,
+        total_workers: int
+    ) -> None:
         torch.set_default_device("cuda")
         current_platform.seed_everything(seed)
         self.seed = seed
+        self.enable_expert_parallel = enable_expert_parallel
+        self.worker_id = worker_id
+        self.total_workers = total_workers
+        self.ep_size = total_workers if enable_expert_parallel else 1
+        self.ep_rank = worker_id if enable_expert_parallel else 0
         # Get the device ID to allocate tensors and kernels
         # on the respective GPU. This is required for Ray to work
         # correctly with multi-GPU tuning on the ROCm platform.
-        self.device_id = int(ray.get_gpu_ids()[0])
+        gpu_ids = ray.get_gpu_ids()
+        if gpu_ids:
+            self.device_id = int(gpu_ids[0])
+        else:
+            self.device_id = 0
+        self.distributed_initialized = False
+
+    def init_distributed(self, master_addr: str, master_port: int) -> None:
+        """Initialize torch.distributed for expert parallel."""
+        if self.distributed_initialized:
+            return
+
+        os.environ['MASTER_ADDR'] = master_addr
+        os.environ['MASTER_PORT'] = str(master_port)
+        os.environ['RANK'] = str(self.worker_id)
+        os.environ['WORLD_SIZE'] = str(self.total_workers)
+
+        if not dist.is_initialized():
+            dist.init_process_group(
+                backend='nccl',
+                world_size=self.total_workers,
+                rank=self.worker_id
+            )
+
+        self.distributed_initialized = True
+
+        # Set device using local device ID
+        # Ray workers see their assigned GPU as device 0
+        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+            torch.cuda.set_device(0)
+
+    def get_node_ip(self) -> str:
+        """Get the IP address of this worker node."""
+        import socket
+        return socket.gethostbyname(socket.gethostname())
+
     def benchmark(
         self,
@@ -439,6 +528,9 @@
             num_iters=100,
             block_quant_shape=block_quant_shape,
             use_deep_gemm=use_deep_gemm,
+            enable_expert_parallel=self.enable_expert_parallel,
+            ep_size=self.ep_size,
+            ep_rank=self.ep_rank,
         )
         return config, kernel_time
@@ -491,6 +583,9 @@
                     num_iters=20,
                     block_quant_shape=block_quant_shape,
                     use_deep_gemm=use_deep_gemm,
+                    enable_expert_parallel=self.enable_expert_parallel,
+                    ep_size=self.ep_size,
+                    ep_rank=self.ep_rank,
                 )
             except triton.runtime.autotuner.OutOfResources:
                 # Some configurations may be invalid and fail to compile.
@@ -547,6 +642,7 @@
     )

+    print(f"Writing best config to {filename}...")
     with open(filename, "w") as f:
         json.dump(configs, f, indent=4)
         f.write("\n")
@@ -639,7 +735,42 @@

     ray.init()
     num_gpus = int(ray.available_resources()["GPU"])
-    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
+
+    if args.enable_expert_parallel:
+        if args.tp_size != num_gpus:
+            raise ValueError(
+                "When running with --enable-expert-parallel, the specified "
+                "--tp-size must be equal to the number of available GPUs. "
+                f"Got --tp-size={args.tp_size} and {num_gpus} GPUs.\n"
+                "To tune for a specific number of GPUs for expert parallel, "
+                "please restrict the visible devices using the "
+                "CUDA_VISIBLE_DEVICES environment variable."
+            )
+        if args.tp_size < 2:
+            raise ValueError("Expert parallel requires tensor parallel size >= 2")
Review comment on the check above: The error message "Expert parallel requires tensor parallel size >= 2" might be confusing.
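Purely as a hypothetical illustration of a more explicit check (the helper name and the wording below come from neither the PR nor the reviewer): since --enable-expert-parallel already forces --tp-size to equal the number of visible GPUs, the condition being tested is effectively the GPU count.

```python
# Hypothetical illustration only -- not the reviewer's actual suggestion.
def check_ep_world_size(tp_size: int, num_gpus: int) -> None:
    # With expert parallel enabled, tp_size must already equal the number of
    # visible GPUs, so failing this check really means "fewer than 2 GPUs".
    if tp_size < 2:
        raise ValueError(
            "Expert parallel needs at least 2 GPUs, "
            f"but got --tp-size={tp_size} with {num_gpus} visible GPU(s)."
        )


check_ep_world_size(tp_size=2, num_gpus=2)  # passes silently
```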
+
+    workers = [BenchmarkWorker.remote(
+        args.seed,
+        args.enable_expert_parallel,
+        worker_id=i,
+        total_workers=num_gpus
+    ) for i in range(num_gpus)]
+
+    # Initialize distributed communication for expert parallel
+    if args.enable_expert_parallel:
+        # Get worker IPs to determine master
+        worker_ips = ray.get([w.get_node_ip.remote() for w in workers])
+
+        # Use first worker's IP as master
+        master_addr = worker_ips[0]
+        master_port = get_open_port()
+
+        # Initialize distributed on all workers
+        init_futures = [
+            w.init_distributed.remote(master_addr, master_port)
+            for w in workers
+        ]
+        ray.get(init_futures)
+        print(f"Initialized distributed environment with master at {master_addr}:{master_port}")

     def _distribute(method: str, inputs: list[Any]) -> list[Any]:
         outputs = []
@@ -735,6 +866,7 @@
     parser.add_argument("--tune", action="store_true")
     parser.add_argument("--trust-remote-code", action="store_true")
     parser.add_argument("--model-prefix", type=str, required=False)
+    parser.add_argument("--enable-expert-parallel", action="store_true")
     args = parser.parse_args()

     main(args)
Review comment on the get_node_ip helper: The current method for obtaining the node's IP address (socket.gethostbyname(socket.gethostname())) can be unreliable in environments with multiple network interfaces, which might lead to issues in setting up the distributed environment. Ray provides a more robust utility, ray.util.get_node_ip_address(), which is specifically designed to correctly identify the node's IP within a Ray cluster. Using this utility would enhance the reliability of IP address resolution.
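A minimal sketch of what adopting that suggestion could look like, shown only as an illustration (the class is trimmed to the single method in question; the diff above keeps the socket-based version):

```python
import ray
from ray.util import get_node_ip_address


@ray.remote(num_gpus=1)
class BenchmarkWorker:  # trimmed to the one method relevant to this comment
    def get_node_ip(self) -> str:
        """Get the IP address of this worker node via Ray's own resolver."""
        # More robust than socket.gethostbyname(socket.gethostname()) on hosts
        # with multiple network interfaces, because Ray reports the address it
        # actually uses to communicate with this node.
        return get_node_ip_address()
```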