@@ -2593,7 +2593,13 @@ def batch_isend_irecv(p2p_op_list):
 
 
 @_exception_logger
-def broadcast(tensor, src, group=None, async_op=False):
+def broadcast(
+    tensor: torch.Tensor,
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    async_op: bool = False,
+    group_src: Optional[int] = None,
+):
     """
     Broadcasts the tensor to the whole group.
 
@@ -2607,29 +2613,26 @@ def broadcast(tensor, src, group=None, async_op=False):
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
+        group_src (int): Source rank on ``group``. Must specify one of ``group_src``
+            and ``src`` but not both.
 
     Returns:
         Async work handle, if async_op is set to True.
         None, if not async_op or if not part of the group
 
     """
+    group = _group_or_default_group(group)
+    group_src = _canonicalize_group_rank(group, src, group_src, return_global=False)
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
         _warn_not_in_group("broadcast")
         return
 
     opts = BroadcastOptions()
-    opts.rootRank = src
+    opts.rootRank = group_src
     opts.rootTensor = 0
     opts.asyncOp = async_op
-
-    if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        work = default_pg.broadcast([tensor], opts)
-    else:
-        group_src_rank = get_group_rank(group, src)
-        opts.rootRank = group_src_rank
-        work = group.broadcast([tensor], opts)
+    work = group.broadcast([tensor], opts)
     if async_op:
         return work
     else:
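
For context, a minimal usage sketch of the new keyword (editor's illustration, not part of the diff; it assumes four ranks launched with torchrun and the gloo backend). After this change the broadcast root can be named either by its global rank via ``src`` or by its rank inside ``group`` via ``group_src``:

# Editor's sketch, not part of the diff. Assumes 4 ranks and the gloo backend.
import torch
import torch.distributed as dist

dist.init_process_group("gloo")
pg = dist.new_group(ranks=[2, 3])  # group-local rank 0 -> global 2, 1 -> global 3

t = torch.zeros(1)
if dist.get_rank() in (2, 3):
    dist.broadcast(t, src=3, group=pg)        # root named by global rank (as before)
    dist.broadcast(t, group_src=1, group=pg)  # root named by its rank within pg (new)

dist.destroy_process_group()
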
@@ -2783,7 +2786,14 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
 
 
 @_exception_logger
-def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
+def reduce(
+    tensor: torch.Tensor,
+    dst: Optional[int] = None,
+    op=ReduceOp.SUM,
+    group: Optional[ProcessGroup] = None,
+    async_op: bool = False,
+    group_dst: Optional[int] = None,
+):
     """
     Reduces the tensor data across all machines.
 
@@ -2799,29 +2809,25 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
+        group_dst (int): Destination rank on ``group``. Must specify one of ``group_dst``
+            and ``dst`` but not both.
 
     Returns:
         Async work handle, if async_op is set to True.
         None, if not async_op or if not part of the group
 
     """
+    group = _group_or_default_group(group)
+    group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False)
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
         _warn_not_in_group("reduce")
         return
 
     opts = ReduceOptions()
     opts.reduceOp = op
-    opts.rootRank = dst
-
-    if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        work = default_pg.reduce([tensor], opts)
-    else:
-        group_dst_rank = get_group_rank(group, dst)
-        opts.rootRank = group_dst_rank
-        work = group.reduce([tensor], opts)
-
+    opts.rootRank = group_dst
+    work = group.reduce([tensor], opts)
     if async_op:
         return work
     else:
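
A corresponding sketch for ``reduce`` (again an editor's illustration, not part of the diff; assumes four ranks and gloo): ``dst`` still takes a global rank, while the new ``group_dst`` addresses the destination by its rank inside ``group``.

# Editor's sketch, not part of the diff. Assumes 4 ranks and the gloo backend.
import torch
import torch.distributed as dist

dist.init_process_group("gloo")
pg = dist.new_group(ranks=[0, 2])  # group-local rank 0 -> global 0, 1 -> global 2

t = torch.ones(1) * dist.get_rank()
if dist.get_rank() in (0, 2):
    # Sum both members' tensors onto global rank 2, addressed by its group rank 1.
    dist.reduce(t, group_dst=1, op=dist.ReduceOp.SUM, group=pg)

dist.destroy_process_group()
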
@@ -3270,7 +3276,13 @@ def recv_object_list(object_list, src=None, group=None, device=None):
 
 
 @_exception_logger
-def broadcast_object_list(object_list, src=0, group=None, device=None):
+def broadcast_object_list(
+    object_list: List[Any],
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    group_src: Optional[int] = None,
+):
     """
     Broadcasts picklable objects in ``object_list`` to the whole group.
 
@@ -3289,6 +3301,8 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         device (``torch.device``, optional): If not None, the objects are
             serialized and converted to tensors which are moved to the
             ``device`` before broadcasting. Default is ``None``.
+        group_src (int): Source rank on ``group``. Must specify one of ``group_src``
+            and ``src`` but not both.
 
     Returns:
         ``None``. If rank is part of the group, ``object_list`` will contain the
@@ -3331,6 +3345,10 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         >>> objects
         ['foo', 12, {1: 2}]
     """
+    group = _group_or_default_group(group)
+    if src is None and group_src is None:
+        src = 0
+    global_src = _canonicalize_group_rank(group, src, group_src, return_global=True)
     if _rank_not_in_group(group):
         _warn_not_in_group("broadcast_object_list")
         return
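
Note the ``return_global=True`` here, in contrast to the tensor collectives above: the object-list path forwards the canonicalized rank to ``broadcast()`` as a global ``src``, so whichever of ``src``/``group_src`` was given must be translated to a global rank. The public rank-mapping helpers show the same translation (editor's sketch, not part of the diff; assumes four ranks):

# Editor's sketch, not part of the diff. Assumes 4 ranks and the gloo backend.
import torch.distributed as dist

dist.init_process_group("gloo")
pg = dist.new_group(ranks=[1, 3])

if dist.get_rank() in (1, 3):
    assert dist.get_global_rank(pg, 1) == 3  # group rank 1 in pg is global rank 3
    assert dist.get_group_rank(pg, 3) == 1   # and back again

dist.destroy_process_group()
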
@@ -3342,9 +3360,9 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
     # case it is not ``None`` we move the size and object tensors to be
     # broadcasted to this device.
     current_device = device or _get_object_coll_device(group)
-    my_rank = get_rank()
+    my_global_rank = get_rank()
     # Serialize object_list elements to tensors on src rank.
-    if my_rank == src:
+    if my_global_rank == global_src:
         tensor_list, size_list = zip(
             *[_object_to_tensor(obj, current_device, group) for obj in object_list]
         )
@@ -3355,12 +3373,12 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         )
 
     # Broadcast object sizes
-    broadcast(object_sizes_tensor, src=src, group=group)
+    broadcast(object_sizes_tensor, src=global_src, group=group)
 
     # Concatenate and broadcast serialized object tensors
     # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
     # has only one element, we can skip the copy.
-    if my_rank == src:
+    if my_global_rank == global_src:
         if len(tensor_list) == 1:  # type: ignore[possibly-undefined]
             object_tensor = tensor_list[0]
         else:
@@ -3372,10 +3390,10 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
             device=current_device,
         )
 
-    broadcast(object_tensor, src=src, group=group)
+    broadcast(object_tensor, src=global_src, group=group)
     # Deserialize objects using their stored sizes.
     offset = 0
-    if my_rank != src:
+    if my_global_rank != global_src:
         for i, obj_size in enumerate(object_sizes_tensor):
             obj_view = object_tensor[offset : offset + obj_size]
             obj_view = obj_view.type(torch.uint8)
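
Finally, a usage sketch for ``broadcast_object_list`` (editor's illustration, not part of the diff; assumes two ranks and gloo). When neither ``src`` nor ``group_src`` is passed, the fallback added above keeps the old default of rank 0 as the source:

# Editor's sketch, not part of the diff. Assumes 2 ranks and the gloo backend.
import torch.distributed as dist

dist.init_process_group("gloo")

objects = ["foo", 12, {1: 2}] if dist.get_rank() == 0 else [None, None, None]
dist.broadcast_object_list(objects)               # implicit src=0, unchanged default
dist.broadcast_object_list(objects, group_src=0)  # same source, named group-locally

dist.destroy_process_group()
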