@@ -1048,7 +1048,9 @@ def apply(
                expert_map=expert_map,
                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                shared_experts=shared_experts)
-        elif fused_moe_state == FusedMoEState.AllGather:
+        elif fused_moe_state in [
+                FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
+        ]:
            return fused_experts(hidden_states=x,
                                 w1=layer.w13_weight,
                                 w2=layer.w2_weight,
@@ -1225,6 +1227,22 @@ def __init__(
        self.tp_group = get_tp_group().device_group
        self.quant_method.create_weights(layer=self, **moe_quant_params)

+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor):
+        assert (len(x.shape) == 2)
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(self.dp_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            get_dp_group().broadcast(buffer[start:end, :], idx)
+        return buffer
+
    def forward(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
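
Reviewer note: a minimal single-process sketch of the buffer layout naive_multicast builds. Each DP rank owns the row slice [cu[rank-1], cu[rank]) of a shared buffer, and the broadcast loop fills every other rank's slice; in the real code each remote slice arrives via get_dp_group().broadcast. The values below (dp_size, tokens_per_rank) are illustrative, not taken from the patch.

import torch

dp_size = 4
tokens_per_rank = [3, 1, 4, 2]  # ragged token counts across DP ranks
hidden = 8
cu_tokens_across_dp_cpu = torch.cumsum(torch.tensor(tokens_per_rank), dim=0)

# Per-rank local activations (what `x` would be on each rank).
local_x = [torch.randn(n, hidden) for n in tokens_per_rank]

buffer = torch.empty((int(cu_tokens_across_dp_cpu[-1]), hidden))
for idx in range(dp_size):
    start = 0 if idx == 0 else int(cu_tokens_across_dp_cpu[idx - 1])
    end = int(cu_tokens_across_dp_cpu[idx])
    # In the real code this slice is filled by get_dp_group().broadcast(...).
    buffer[start:end, :].copy_(local_x[idx])

assert buffer.shape[0] == sum(tokens_per_rank)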
@@ -1250,9 +1268,10 @@ def forward(self,
            shared_hidden_states = shared_experts(hidden_states)

        tp_size = get_tensor_model_parallel_world_size()
-        if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
-                and fused_moe_state != FusedMoEState.AllGatherEP
-                and not replace_allreduce):
+        if (tp_size > 1 and fused_moe_state not in [
+                FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
+                FusedMoEState.NaiveMulticast
+        ] and not replace_allreduce):
            if num_tokens < tp_size:
                hidden_states = nn.functional.pad(
                    hidden_states, (0, 0, 0, tp_size - num_tokens))
@@ -1267,21 +1286,31 @@ def forward(self,
            tp_rank = get_tensor_model_parallel_rank()
            hidden_states = chunk_hidden_states[tp_rank]
            router_logits = chunk_router_logits[tp_rank]
-        if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
-            # NOTE: When in torchair graph, it has been padded in model_runner_v1
-            if not self.torchair_graph_enabled or is_prefill:
-                attn_metadata = get_forward_context().attn_metadata
-                if attn_metadata is not None:
-                    max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
-                    if num_tokens < max_num_tokens_across_dp:
-                        hidden_states = nn.functional.pad(
-                            hidden_states,
-                            (0, 0, 0, max_num_tokens_across_dp - num_tokens))
-                        router_logits = nn.functional.pad(
-                            router_logits,
-                            (0, 0, 0, max_num_tokens_across_dp - num_tokens))
-            hidden_states = get_dp_group().all_gather(hidden_states, 0)
-            router_logits = get_dp_group().all_gather(router_logits, 0)
+        if self.dp_size > 1:
+            if fused_moe_state == FusedMoEState.AllGather:
+                # NOTE: When in torchair graph, it has been padded in model_runner_v1
+                if not self.torchair_graph_enabled:
+                    attn_metadata = get_forward_context().attn_metadata
+                    if attn_metadata is not None:
+                        max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
+                        if num_tokens < max_num_tokens_across_dp:
+                            hidden_states = nn.functional.pad(
+                                hidden_states,
+                                (0, 0, 0,
+                                 max_num_tokens_across_dp - num_tokens))
+                            router_logits = nn.functional.pad(
+                                router_logits,
+                                (0, 0, 0,
+                                 max_num_tokens_across_dp - num_tokens))
+                hidden_states = get_dp_group().all_gather(hidden_states, 0)
+                router_logits = get_dp_group().all_gather(router_logits, 0)
+            elif fused_moe_state == FusedMoEState.NaiveMulticast:
+                cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_dp_cpu
+                hidden_states = self.naive_multicast(hidden_states,
+                                                     cu_tokens_across_dp_cpu)
+                router_logits = self.naive_multicast(router_logits,
+                                                     cu_tokens_across_dp_cpu)

        # Matrix multiply.
        e_hidden_states = self.quant_method.apply(
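
Reviewer note: a hedged sketch of why the AllGather branch pads while NaiveMulticast does not. A collective all_gather needs equal shapes on every rank, so each rank pads its rows up to the DP-wide maximum before gathering; naive_multicast broadcasts each rank's exact slice instead, so no padding is needed. Shapes below are illustrative only.

import torch
import torch.nn as nn

num_tokens = 3
max_num_tokens_across_dp = 5
hidden_states = torch.randn(num_tokens, 8)
router_logits = torch.randn(num_tokens, 16)

# Pad rows only: (left, right, top, bottom) for a 2-D tensor.
pad = (0, 0, 0, max_num_tokens_across_dp - num_tokens)
hidden_states = nn.functional.pad(hidden_states, pad)
router_logits = nn.functional.pad(router_logits, pad)
assert hidden_states.shape[0] == max_num_tokens_across_dp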
@@ -1310,28 +1339,40 @@ def forward(self,
        if isinstance(e_hidden_states, tuple):
            e_hidden_states, shared_hidden_states = e_hidden_states

-        if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
-                and fused_moe_state != FusedMoEState.AllGatherEP
-                and not replace_allreduce):
+        if (tp_size > 1 and fused_moe_state not in [
+                FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
+                FusedMoEState.NaiveMulticast
+        ] and not replace_allreduce):
            dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                            self.tp_group)
            final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
            if num_tokens < tp_size:
                final_hidden_states = final_hidden_states[:num_tokens]
            dispose_tensor(e_hidden_states)
-        elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
-            final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
-                e_hidden_states,
-                "sum",
-                scatter_dim=0,
-                group=get_dp_group().device_group)
-            final_hidden_states = final_hidden_states[:num_tokens]
-            dispose_tensor(e_hidden_states)
+        elif self.dp_size > 1:
+            if fused_moe_state == FusedMoEState.NaiveMulticast:
+                start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+                    self.dp_rank - 1]
+                end = cu_tokens_across_dp_cpu[self.dp_rank]
+                final_hidden_states = get_dp_group().all_reduce(
+                    e_hidden_states)
+                final_hidden_states = final_hidden_states[start:end, :]
+                dispose_tensor(e_hidden_states)
+            elif fused_moe_state == FusedMoEState.AllGather:
+                final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                    e_hidden_states,
+                    "sum",
+                    scatter_dim=0,
+                    group=get_dp_group().device_group)
+                final_hidden_states = final_hidden_states[:num_tokens]
+                dispose_tensor(e_hidden_states)
        else:
            final_hidden_states = e_hidden_states

-        if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather
-                            or fused_moe_state == FusedMoEState.AllGatherEP):
+        if tp_size > 1 and fused_moe_state in [
+                FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
+                FusedMoEState.NaiveMulticast
+        ]:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
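
Reviewer note: a single-process sketch of the NaiveMulticast combine step above. After the experts run on the concatenated DP batch, an all_reduce sums the per-rank partial results and each rank keeps only its own row slice [start, end). The rank and cumulative counts below are illustrative, and the DP all_reduce is elided since it does not affect the slicing logic shown.

import torch

cu_tokens_across_dp_cpu = torch.tensor([3, 4, 8, 10])
dp_rank = 2
e_hidden_states = torch.randn(int(cu_tokens_across_dp_cpu[-1]), 8)

# get_dp_group().all_reduce(e_hidden_states) would sum across DP ranks here.
final_hidden_states = e_hidden_states

start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
end = int(cu_tokens_across_dp_cpu[dp_rank])
final_hidden_states = final_hidden_states[start:end, :]
assert final_hidden_states.shape[0] == end - start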