import ray
import torch
+import torch.distributed as dist
from ray.experimental.tqdm_ray import tqdm

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, get_open_port

FP8_DTYPE = current_platform.fp8_dtype()


+def build_expert_map(global_num_experts: int,
+                     ep_size: int,
+                     ep_rank: int) -> tuple[int, torch.Tensor]:
+    """Build expert map for expert parallel. Returns (local_num_experts, expert_map)."""
+    # Calculate the base number of experts per rank
+    base_experts = global_num_experts // ep_size
+
+    # Create the expert map; -1 marks experts owned by other ranks
+    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
+
+    if ep_rank < (ep_size - 1):
+        # Non-last ranks get the base number of experts
+        local_num_experts = base_experts
+        start = ep_rank * base_experts
+        end = start + local_num_experts
+        expert_map[start:end] = torch.arange(local_num_experts, dtype=torch.int32)
+    else:
+        # The last rank gets all remaining experts
+        start = ep_rank * base_experts
+        local_num_experts = global_num_experts - start
+        expert_map[start:] = torch.arange(local_num_experts, dtype=torch.int32)
+
+    return local_num_experts, expert_map.cuda()
+
+
class BenchmarkConfig(TypedDict):
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
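For illustration only (not part of the diff): the sketch below re-runs the same mapping logic on CPU for a hypothetical setup of 10 global experts split across `ep_size=4` ranks, viewed from the last rank. The helper name `build_expert_map_cpu` and the concrete sizes are invented for the example; the only difference from the diff's `build_expert_map` is that the final `.cuda()` move is dropped so it runs without a GPU.

```python
import torch

def build_expert_map_cpu(global_num_experts: int, ep_size: int, ep_rank: int):
    # Same mapping logic as build_expert_map above, minus the .cuda() move.
    base_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    if ep_rank < ep_size - 1:
        local_num_experts = base_experts
        start = ep_rank * base_experts
        expert_map[start:start + local_num_experts] = torch.arange(
            local_num_experts, dtype=torch.int32)
    else:
        start = ep_rank * base_experts
        local_num_experts = global_num_experts - start
        expert_map[start:] = torch.arange(local_num_experts, dtype=torch.int32)
    return local_num_experts, expert_map

local, emap = build_expert_map_cpu(10, ep_size=4, ep_rank=3)
print(local)  # 4, the last rank absorbs the remainder (10 - 3 * 2)
print(emap)   # tensor([-1, -1, -1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32)
```

Ranks 0 through 2 would each own two experts; only the last rank's share grows when `global_num_experts` is not divisible by `ep_size`.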
@@ -47,20 +73,25 @@ def benchmark_config(
    use_deep_gemm: bool = False,
    enable_expert_parallel: bool = False,
    ep_size: int = 1,
+    ep_rank: int = 0,
) -> float:
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)

-    # For expert parallel, only create weights for local experts
+    # For expert parallel, calculate local and global expert counts
+    global_num_experts = num_experts
    if enable_expert_parallel:
-        num_experts = num_experts // ep_size
+        local_num_experts, expert_map = build_expert_map(global_num_experts, ep_size, ep_rank)
+    else:
+        local_num_experts = num_experts
+        expert_map = None

    if use_int8_w8a16:
        w1 = torch.randint(
            -127,
            127,
            (
-                num_experts,
+                local_num_experts,
                shard_intermediate_size,
                hidden_size,
            ),
@@ -70,37 +101,38 @@ def benchmark_config(
            -127,
            127,
            (
-                num_experts,
+                local_num_experts,
                hidden_size,
                shard_intermediate_size // 2,
            ),
            dtype=torch.int8,
        )
    else:
        w1 = torch.randn(
-            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
+            local_num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
-            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
+            local_num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
-    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
+    # Gating output uses the global number of experts
+    gating_output = torch.randn(num_iters, num_tokens, global_num_experts, dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn(
-            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
+            (local_num_experts, 2 * shard_intermediate_size), dtype=torch.float32
        )
-        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
+        w2_scale = torch.randn((hidden_size, local_num_experts), dtype=torch.float32)
    if use_deep_gemm:
        # we use the default block shape for deepgemm
        block_quant_shape = [128, 128]
    if use_fp8_w8a8:
        if block_quant_shape:
            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
-            E = num_experts
+            E = local_num_experts
            N = shard_intermediate_size // 2
            K = hidden_size
            factor_for_scale = 1e-2
@@ -117,16 +149,16 @@ def benchmark_config(
                * factor_for_scale
            )
        else:
-            w1_scale = torch.randn(num_experts, dtype=torch.float32)
-            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+            w1_scale = torch.randn(local_num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(local_num_experts, dtype=torch.float32)

        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

-    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
+    input_gating = torch.empty(num_tokens, global_num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])
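Aside (not part of the diff): the net effect of the hunks above is that expert weights and their scales are allocated with `local_num_experts`, while the routing tensors (`gating_output`, `input_gating`) keep the global expert count. A minimal shape check, with all sizes invented purely for illustration:

```python
import torch

# Hypothetical sizes, chosen only to make the shapes easy to read.
num_tokens, hidden_size, shard_intermediate_size = 4, 8, 16
global_num_experts, local_num_experts = 6, 3  # 6 experts split over ep_size=2

w1 = torch.randn(local_num_experts, shard_intermediate_size, hidden_size)
w2 = torch.randn(local_num_experts, hidden_size, shard_intermediate_size // 2)
gating_output = torch.randn(num_tokens, global_num_experts)

print(w1.shape, w2.shape, gating_output.shape)
# torch.Size([3, 16, 8]) torch.Size([3, 8, 8]) torch.Size([4, 6])
```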
@@ -153,6 +185,8 @@ def run():
                    a2_scale=a2_scale,
                    block_shape=block_quant_shape,
                    allow_deep_gemm=True,
+                    global_num_experts=global_num_experts,
+                    expert_map=expert_map,
                )
            else:
                fused_moe(
@@ -170,6 +204,8 @@ def run():
                    a1_scale=a1_scale,
                    a2_scale=a2_scale,
                    block_shape=block_quant_shape,
+                    global_num_experts=global_num_experts,
+                    expert_map=expert_map,
                )

    # JIT compilation & warmup
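A note on how such a map is meant to be read (my reading of the construction in `build_expert_map`, not something stated explicitly in the diff): `expert_map` translates a global expert id into a local weight slot, and -1 flags experts that live on another rank. The tensor values below are invented; for example, on a rank that owns global experts 2 and 3:

```python
import torch

expert_map = torch.tensor([-1, -1, 0, 1], dtype=torch.int32)  # this rank owns global experts 2 and 3
topk_ids = torch.tensor([0, 2, 3, 1])  # global expert ids picked by the router for four tokens

local_ids = expert_map[topk_ids]  # tensor([-1,  0,  1, -1], dtype=torch.int32)
is_local = local_ids >= 0         # tensor([False,  True,  True, False])
```

How `fused_moe` actually consumes `global_num_experts` and `expert_map` internally is up to the kernel; the snippet only illustrates the indexing convention implied by the map's construction.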
@@ -410,7 +446,41 @@ def __init__(
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
-        self.device_id = int(ray.get_gpu_ids()[0])
+        gpu_ids = ray.get_gpu_ids()
+        if gpu_ids:
+            self.device_id = int(gpu_ids[0])
+        else:
+            self.device_id = 0
+        self.distributed_initialized = False
+
+    def init_distributed(self, master_addr: str, master_port: int) -> None:
+        """Initialize torch.distributed for expert parallel."""
+        if self.distributed_initialized:
+            return
+
+        os.environ['MASTER_ADDR'] = master_addr
+        os.environ['MASTER_PORT'] = str(master_port)
+        os.environ['RANK'] = str(self.worker_id)
+        os.environ['WORLD_SIZE'] = str(self.total_workers)
+
+        if not dist.is_initialized():
+            dist.init_process_group(
+                backend='nccl',
+                world_size=self.total_workers,
+                rank=self.worker_id
+            )
+
+        self.distributed_initialized = True
+
+        # Set device using local device ID
+        # Ray workers see their assigned GPU as device 0
+        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
+            torch.cuda.set_device(0)
+
+    def get_node_ip(self) -> str:
+        """Get the IP address of this worker node."""
+        import socket
+        return socket.gethostbyname(socket.gethostname())

    def benchmark(
        self,
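For orientation (not part of the diff): `init_distributed` uses the standard torch.distributed environment-variable rendezvous. Below is a minimal single-process sketch of the same pattern; the address, port, and the gloo backend are assumptions chosen so it runs on a machine without GPUs, whereas the benchmark sets these per Ray worker and uses NCCL with one rank per GPU.

```python
import os
import torch.distributed as dist

# Assumed values for a one-process run.
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29555'
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'

if not dist.is_initialized():
    dist.init_process_group(backend='gloo', world_size=1, rank=0)

print(dist.get_rank(), dist.get_world_size())  # 0 1
dist.destroy_process_group()
```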
@@ -461,6 +531,7 @@ def benchmark(
            use_deep_gemm=use_deep_gemm,
            enable_expert_parallel=self.enable_expert_parallel,
            ep_size=self.ep_size,
+            ep_rank=self.ep_rank,
        )
        return config, kernel_time
@@ -515,6 +586,7 @@ def tune(
                    use_deep_gemm=use_deep_gemm,
                    enable_expert_parallel=self.enable_expert_parallel,
                    ep_size=self.ep_size,
+                    ep_rank=self.ep_rank,
                )
            except triton.runtime.autotuner.OutOfResources:
                # Some configurations may be invalid and fail to compile.
@@ -581,6 +653,7 @@ def save_configs(
    )

    print(f"Writing best config to {filename}...")
+
    with open(filename, "w") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")
@@ -692,11 +765,12 @@ def main(args: argparse.Namespace):
    if args.enable_expert_parallel:
        if args.tp_size != num_gpus:
            raise ValueError(
-                "When running with --enable-expert-parallel, the specified "
-                "--tp-size must be equal to the number of available GPUs. "
+                "When running with --enable-expert-parallel, the EP size is "
+                "automatically set to the number of available GPUs, and --tp-size "
+                "must match this value for correct benchmarking. "
                f"Got --tp-size={args.tp_size} and {num_gpus} GPUs.\n"
-                "To tune for a specific number of GPUs for expert parallel, "
-                "please restrict the visible devices using the CUDA_VISIBLE_DEVICES"
+                "To benchmark with a specific EP size, please restrict the visible "
+                "devices using the CUDA_VISIBLE_DEVICES environment variable."
            )
        if args.tp_size < 2:
            raise ValueError("Expert parallel requires tensor parallel size >= 2")
@@ -708,6 +782,23 @@ def main(args: argparse.Namespace):
        total_workers=num_gpus
    ) for i in range(num_gpus)]

+    # Initialize distributed communication for expert parallel
+    if args.enable_expert_parallel:
+        # Get worker IPs to determine the master
+        worker_ips = ray.get([w.get_node_ip.remote() for w in workers])
+
+        # Use the first worker's IP as the master address
+        master_addr = worker_ips[0]
+        master_port = get_open_port()
+
+        # Initialize distributed on all workers
+        init_futures = [
+            w.init_distributed.remote(master_addr, master_port)
+            for w in workers
+        ]
+        ray.get(init_futures)
+        print(f"Initialized distributed environment with master at {master_addr}:{master_port}")
+
    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        outputs = []
        worker_idx = 0