
Commit 3f8f38b

wconstab and Ryo-not-rio authored and committed
[C10D] support group_src/dst in broadcast/reduce ops (pytorch#140843)
Also add mypy annotations.

Partially addresses RFC 0042 (pytorch/rfcs#71). See more details/motivation in pytorch#140460.

Pull Request resolved: pytorch#140843
Approved by: https://github.com/kwen2501
1 parent ea7a194 commit 3f8f38b
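
For context, a minimal usage sketch of the new keyword arguments this commit adds. The 4-rank, two-subgroup setup is illustrative only and mirrors the tests below; it is not part of the commit itself:

import torch
import torch.distributed as dist

# Illustrative setup: 4 ranks launched with torchrun, one CUDA device per rank.
dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Every rank must create every group; each rank then picks the one it belongs to.
pg01 = dist.new_group([0, 1])
pg23 = dist.new_group([2, 3])
subgroup = pg01 if rank in (0, 1) else pg23

x = torch.ones(10, device="cuda") * rank
# group_src/group_dst are ranks *within* `subgroup` (0 or 1), not global ranks.
dist.broadcast(x, group_src=1, group=subgroup)
dist.reduce(x, group_dst=0, group=subgroup)
# Equivalent calls using the pre-existing global-rank arguments:
#   dist.broadcast(x, src=dist.get_global_rank(subgroup, 1), group=subgroup)
#   dist.reduce(x, dst=dist.get_global_rank(subgroup, 0), group=subgroup)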

File tree

3 files changed (+89, -40 lines)

test/distributed/test_c10d_nccl.py

Lines changed: 40 additions & 11 deletions

@@ -3840,7 +3840,8 @@ def test_gather_object_subgroup(self, group_rank):
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_reduce_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_reduce_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3849,10 +3850,16 @@ def test_reduce_subgroup(self):
         x = torch.ones((10,), device=device) * self.rank
         if self.rank == 0 or self.rank == 2:
             expected = x + torch.ones((10,), device=device) * (self.rank + 1)
-            c10d.reduce(x, dst=self.rank, group=subgroup, async_op=False)
+            if group_rank:
+                c10d.reduce(x, group_dst=0, group=subgroup, async_op=False)
+            else:
+                c10d.reduce(x, dst=self.rank, group=subgroup, async_op=False)
             self.assertEqual(x, expected)
         else:
-            c10d.reduce(x, dst=self.rank - 1, group=subgroup, async_op=False)
+            if group_rank:
+                c10d.reduce(x, group_dst=0, group=subgroup, async_op=False)
+            else:
+                c10d.reduce(x, dst=self.rank - 1, group=subgroup, async_op=False)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
@@ -3893,20 +3900,27 @@ def test_send_recv_subgroup(self, async_op, group_rank):
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_broadcast_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_broadcast_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             return
         subgroup = self._init_two_pg2_subgroups(world_size)
         device = torch.device("cuda:%d" % self.rank)
         if self.rank == 0 or self.rank == 2:
             x = torch.empty((10,), device=device)
-            c10d.broadcast(x, src=self.rank + 1, group=subgroup)
+            if group_rank:
+                c10d.broadcast(x, group_src=1, group=subgroup)
+            else:
+                c10d.broadcast(x, src=self.rank + 1, group=subgroup)
             expected = torch.ones((10,), device=device) * (self.rank + 1)
             self.assertEqual(x, expected)
         else:
             x = torch.ones((10,), device=device) * self.rank
-            c10d.broadcast(x, src=self.rank, group=subgroup)
+            if group_rank:
+                c10d.broadcast(x, group_src=1, group=subgroup)
+            else:
+                c10d.broadcast(x, src=self.rank, group=subgroup)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
@@ -3939,7 +3953,10 @@ def test_send_recv_object_list_subgroup(self, set_device: SetDeviceMethod):
         "set_device",
         [SetDeviceMethod.TORCH_CUDA_SET, SetDeviceMethod.COLLECTIVE_ARGUMENT],
     )
-    def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):
+    @parametrize("group_rank", [True, False])
+    def test_broadcast_object_list_subgroup(
+        self, set_device: SetDeviceMethod, group_rank
+    ):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3951,14 +3968,26 @@ def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):
         device = torch.device("cuda:%d" % self.rank)
         if self.rank == 0 or self.rank == 2:
             x = [{}]
-            c10d.broadcast_object_list(
-                x, src=self.rank + 1, group=subgroup, device=device
-            )
+            if group_rank:
+                c10d.broadcast_object_list(
+                    x, group_src=1, group=subgroup, device=device
+                )
+            else:
+                c10d.broadcast_object_list(
+                    x, src=self.rank + 1, group=subgroup, device=device
+                )
             expected = [{"rank": self.rank + 1}]
             self.assertEqual(x, expected)
         else:
             x = [{"rank": self.rank}]
-            c10d.broadcast_object_list(x, src=self.rank, group=subgroup, device=device)
+            if group_rank:
+                c10d.broadcast_object_list(
+                    x, group_src=1, group=subgroup, device=device
+                )
+            else:
+                c10d.broadcast_object_list(
+                    x, src=self.rank, group=subgroup, device=device
+                )
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
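
The group_rank=True branches work because, in each 2-rank subgroup created by _init_two_pg2_subgroups, group rank 0 maps to exactly the global rank that the group_rank=False branch passes. A small illustration, assuming the same subgroups over global ranks [0, 1] and [2, 3]:

pg23 = dist.new_group([2, 3])
dist.get_group_rank(pg23, 2)   # -> 0: global rank 2 is group rank 0 in pg23
dist.get_global_rank(pg23, 1)  # -> 3: group rank 1 in pg23 is global rank 3
# So on ranks 2 and 3, c10d.reduce(x, group_dst=0, group=pg23) targets the
# same process as c10d.reduce(x, dst=2, group=pg23).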

torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py

Lines changed: 3 additions & 1 deletion

@@ -88,9 +88,11 @@ def _broadcast_bucket(
     for assigned_rank in assigned_ranks:
         bucket_assignments = zero._bucket_assignments_per_rank[assigned_rank]
         if bucket_index in bucket_assignments:
+            send_tensor = bucket_assignments[bucket_index].tensor
+            assert send_tensor is not None
             overlap_info.broadcast_handles.append(
                 dist.broadcast(
-                    bucket_assignments[bucket_index].tensor,
+                    send_tensor,
                     src=dist.get_global_rank(zero.process_group, assigned_rank),
                     group=zero.process_group,
                     async_op=True,
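
This hook change only exists to satisfy mypy: the .tensor attribute is typed as Optional, so binding it to a local and asserting it is not None narrows the type before it is passed to dist.broadcast, which expects a plain tensor. The narrowing pattern in isolation (names are illustrative, not from the PR):

from typing import Optional
import torch

def pick_tensor(maybe: Optional[torch.Tensor]) -> torch.Tensor:
    send_tensor = maybe
    assert send_tensor is not None  # narrows Optional[Tensor] to Tensor for mypy
    return send_tensor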

torch/distributed/distributed_c10d.py

Lines changed: 46 additions & 28 deletions

@@ -2593,7 +2593,13 @@ def batch_isend_irecv(p2p_op_list):
 
 
 @_exception_logger
-def broadcast(tensor, src, group=None, async_op=False):
+def broadcast(
+    tensor: torch.Tensor,
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    async_op: bool = False,
+    group_src: Optional[int] = None,
+):
     """
     Broadcasts the tensor to the whole group.
 
@@ -2607,29 +2613,26 @@ def broadcast(tensor, src, group=None, async_op=False):
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
+        group_src (int): Source rank on ``group``. Must specify one of ``group_src``
+            and ``src`` but not both.
 
     Returns:
         Async work handle, if async_op is set to True.
         None, if not async_op or if not part of the group
 
     """
+    group = _group_or_default_group(group)
+    group_src = _canonicalize_group_rank(group, src, group_src, return_global=False)
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
         _warn_not_in_group("broadcast")
         return
 
     opts = BroadcastOptions()
-    opts.rootRank = src
+    opts.rootRank = group_src
     opts.rootTensor = 0
     opts.asyncOp = async_op
-
-    if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        work = default_pg.broadcast([tensor], opts)
-    else:
-        group_src_rank = get_group_rank(group, src)
-        opts.rootRank = group_src_rank
-        work = group.broadcast([tensor], opts)
+    work = group.broadcast([tensor], opts)
     if async_op:
         return work
     else:
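
_group_or_default_group and _canonicalize_group_rank are helpers inside distributed_c10d.py that are not shown in this diff. A hypothetical sketch of their assumed contract, inferred only from how broadcast/reduce call them here (the real implementations may differ):

# Hypothetical sketch only; not the code from this PR.
def _group_or_default_group(group=None):
    # Fall back to the default (WORLD) process group, as the removed
    # `_get_default_group()` branch used to do.
    return group if group is not None else _get_default_group()

def _canonicalize_group_rank(group, global_rank, group_rank, return_global=False):
    # Exactly one of the global rank (src/dst) and the group-local rank
    # (group_src/group_dst) may be given; translate to the requested form.
    if (global_rank is None) == (group_rank is None):
        raise ValueError("specify exactly one of a global rank and a group rank")
    if return_global:
        return global_rank if global_rank is not None else get_global_rank(group, group_rank)
    return group_rank if group_rank is not None else get_group_rank(group, global_rank)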
@@ -2783,7 +2786,14 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
 
 
 @_exception_logger
-def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
+def reduce(
+    tensor: torch.Tensor,
+    dst: Optional[int] = None,
+    op=ReduceOp.SUM,
+    group: Optional[ProcessGroup] = None,
+    async_op: bool = False,
+    group_dst: Optional[int] = None,
+):
     """
     Reduces the tensor data across all machines.
 
@@ -2799,29 +2809,25 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
+        group_dst (int): Destination rank on ``group``. Must specify one of ``group_dst``
+            and ``dst`` but not both.
 
     Returns:
         Async work handle, if async_op is set to True.
         None, if not async_op or if not part of the group
 
     """
+    group = _group_or_default_group(group)
+    group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False)
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
         _warn_not_in_group("reduce")
         return
 
     opts = ReduceOptions()
     opts.reduceOp = op
-    opts.rootRank = dst
-
-    if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        work = default_pg.reduce([tensor], opts)
-    else:
-        group_dst_rank = get_group_rank(group, dst)
-        opts.rootRank = group_dst_rank
-        work = group.reduce([tensor], opts)
-
+    opts.rootRank = group_dst
+    work = group.reduce([tensor], opts)
    if async_op:
        return work
    else:
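
As with the global-rank form, the group-rank form composes with async_op. A brief illustrative sketch, reusing the subgroup from the usage example near the top:

work = dist.reduce(x, group_dst=0, group=subgroup, async_op=True)
work.wait()  # x on group rank 0 of `subgroup` now holds the sum over the subgroup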
@@ -3270,7 +3276,13 @@ def recv_object_list(object_list, src=None, group=None, device=None):
 
 
 @_exception_logger
-def broadcast_object_list(object_list, src=0, group=None, device=None):
+def broadcast_object_list(
+    object_list: List[Any],
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    group_src: Optional[int] = None,
+):
     """
     Broadcasts picklable objects in ``object_list`` to the whole group.
 
@@ -3289,6 +3301,8 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         device (``torch.device``, optional): If not None, the objects are
            serialized and converted to tensors which are moved to the
            ``device`` before broadcasting. Default is ``None``.
+        group_src (int): Source rank on ``group``. Must specify one of ``group_src``
+            and ``src`` but not both.
 
     Returns:
         ``None``. If rank is part of the group, ``object_list`` will contain the
@@ -3331,6 +3345,10 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
        >>> objects
        ['foo', 12, {1: 2}]
     """
+    group = _group_or_default_group(group)
+    if src is None and group_src is None:
+        src = 0
+    global_src = _canonicalize_group_rank(group, src, group_src, return_global=True)
     if _rank_not_in_group(group):
         _warn_not_in_group("broadcast_object_list")
         return
@@ -3342,9 +3360,9 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
     # case it is not ``None`` we move the size and object tensors to be
     # broadcasted to this device.
     current_device = device or _get_object_coll_device(group)
-    my_rank = get_rank()
+    my_global_rank = get_rank()
     # Serialize object_list elements to tensors on src rank.
-    if my_rank == src:
+    if my_global_rank == global_src:
         tensor_list, size_list = zip(
             *[_object_to_tensor(obj, current_device, group) for obj in object_list]
         )
@@ -3355,12 +3373,12 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         )
 
     # Broadcast object sizes
-    broadcast(object_sizes_tensor, src=src, group=group)
+    broadcast(object_sizes_tensor, src=global_src, group=group)
 
     # Concatenate and broadcast serialized object tensors
     # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
     # has only one element, we can skip the copy.
-    if my_rank == src:
+    if my_global_rank == global_src:
         if len(tensor_list) == 1:  # type: ignore[possibly-undefined]
             object_tensor = tensor_list[0]
         else:
@@ -3372,10 +3390,10 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
             device=current_device,
         )
 
-    broadcast(object_tensor, src=src, group=group)
+    broadcast(object_tensor, src=global_src, group=group)
     # Deserialize objects using their stored sizes.
     offset = 0
-    if my_rank != src:
+    if my_global_rank != global_src:
         for i, obj_size in enumerate(object_sizes_tensor):
             obj_view = object_tensor[offset : offset + obj_size]
             obj_view = obj_view.type(torch.uint8)
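
A usage sketch for the object-list variant (illustrative, reusing the subgroup from the earlier example). Note that the hunk above keeps the old behavior by falling back to src = 0 only when neither src nor group_src is given:

my_group_rank = dist.get_group_rank(subgroup, dist.get_rank())
objs = [{"payload": 123}] if my_group_rank == 1 else [None]
dist.broadcast_object_list(objs, group_src=1, group=subgroup, device=torch.device("cuda"))
# Every member of `subgroup` now holds the dict provided by its group rank 1.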
