
Commit 07825f8

wconstab authored and pobin6 committed
[C10D] Support group_dst/group_src in c10d send/recv (pytorch#140460)

Partly addresses RFC 0042 (pytorch/rfcs#71). It's annoying that 'dst' (for send) must be a global rank even when a group is passed in, but we can't easily change 'dst' without breaking existing cases. Furthermore, requiring use of a 'global' dst breaks the less common usage pattern of creating a new ProcessGroup object that is not connected to the 'default group' and thus has no logical 'global' ranks.

Pull Request resolved: pytorch#140460
Approved by: https://github.com/d4l3k, https://github.com/kwen2501, https://github.com/fduwjj
1 parent bf8851e commit 07825f8
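
For illustration only (not part of the commit), a minimal sketch of what the new keyword arguments enable, assuming a 4-rank job in which init_process_group has already been called; the subgroup layout and tensor values below are arbitrary:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) was already called with world_size=4.
# Every rank must call new_group for every subgroup being created.
pg_a = dist.new_group(ranks=[0, 1])
pg_b = dist.new_group(ranks=[2, 3])
subgroup = pg_a if dist.get_rank() < 2 else pg_b

x = torch.ones(10)
if dist.get_rank() % 2 == 1:
    # Previously, 'dst' had to be a global rank even when 'group=subgroup' was passed:
    #   dist.send(x, dst=dist.get_rank() - 1, group=subgroup)
    # With this change, 'group_dst' is a rank local to 'subgroup' (0 or 1 here):
    dist.send(x, group_dst=0, group=subgroup)
else:
    dist.recv(x, group_src=1, group=subgroup)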

File tree: 3 files changed (+99, -52 lines)


test/distributed/test_c10d_common.py

Lines changed: 9 additions & 0 deletions

@@ -1775,11 +1775,20 @@ def test_send_recv(self):

         with self.assertRaises(ValueError):
             dist.send(input_tensor, dist.get_rank())
+        with self.assertRaises(ValueError):
+            dist.send(input_tensor, group_dst=dist.get_rank())
+
+        with self.assertRaises(ValueError):
+            dist.send(input_tensor, dist.get_rank(), group_dst=dist.get_rank())
+        with self.assertRaises(ValueError):
+            dist.send(input_tensor)

         # test recv
         input_tensor = torch.zeros(2, 2)
         dist.recv(input_tensor, (self.rank + 1) % self.world_size)
         self.assertEqual(input_tensor, torch.zeros(2, 2) + 2)
+        with self.assertRaises(ValueError):
+            dist.recv(input_tensor, src=0, group_src=0)

         dist.barrier()
         # intentionally not calling into `destroy_process_group` as not all

test/distributed/test_c10d_nccl.py

Lines changed: 18 additions & 5 deletions

@@ -3825,8 +3825,9 @@ def test_reduce_subgroup(self):

     @requires_nccl()
     @skip_if_lt_x_gpu(4)
+    @parametrize("group_rank", [True, False])
     @parametrize("async_op", [True, False])
-    def test_send_recv_subgroup(self, async_op):
+    def test_send_recv_subgroup(self, async_op, group_rank):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3835,17 +3836,29 @@ def test_send_recv_subgroup(self, async_op):
         if self.rank == 0 or self.rank == 2:
             x = torch.empty((10,), device=device)
             if async_op:
-                c10d.irecv(x, src=self.rank + 1, group=subgroup).wait()
+                if group_rank:
+                    c10d.irecv(x, group_src=1, group=subgroup).wait()
+                else:
+                    c10d.irecv(x, src=self.rank + 1, group=subgroup).wait()
             else:
-                c10d.recv(x, src=self.rank + 1, group=subgroup)
+                if group_rank:
+                    c10d.recv(x, group_src=1, group=subgroup)
+                else:
+                    c10d.recv(x, src=self.rank + 1, group=subgroup)
             expected = torch.ones((10,), device=device) * (self.rank + 1)
             self.assertEqual(x, expected)
         else:
             x = torch.ones((10,), device=device) * self.rank
             if async_op:
-                c10d.isend(x, dst=self.rank - 1, group=subgroup).wait()
+                if group_rank:
+                    c10d.isend(x, group_dst=0, group=subgroup).wait()
+                else:
+                    c10d.isend(x, dst=self.rank - 1, group=subgroup).wait()
             else:
-                c10d.send(x, dst=self.rank - 1, group=subgroup)
+                if group_rank:
+                    c10d.send(x, group_dst=0, group=subgroup)
+                else:
+                    c10d.send(x, dst=self.rank - 1, group=subgroup)

     @requires_nccl()
     @skip_if_lt_x_gpu(4)

torch/distributed/distributed_c10d.py

Lines changed: 72 additions & 47 deletions

@@ -1112,6 +1112,38 @@ def _check_tensor_list(param, param_name) -> None:
         )


+def _group_or_default_group(group: Optional[ProcessGroup] = None) -> ProcessGroup:
+    if group is None or group is GroupMember.WORLD:
+        group = _get_default_group()
+    return group
+
+
+def _canonicalize_group_rank(
+    group: ProcessGroup,
+    global_rank: Optional[int] = None,
+    group_rank: Optional[int] = None,
+) -> int:
+    """
+    Helper method to take _either_ a global rank or a group rank and produce a group rank.
+    """
+    if group_rank is not None:
+        if global_rank is not None:
+            raise ValueError("Can't specify both group_rank and global_rank")
+    else:
+        if global_rank is None:
+            raise ValueError("Must specify global_rank or group_rank")
+        group_rank = get_group_rank(group, global_rank)
+    return group_rank
+
+
+def _check_not_self_rank(group: ProcessGroup, rank: int, rank_type: str):
+    if group.rank() == rank:
+        raise ValueError(
+            f"Invalid {rank_type} rank: {rank_type} rank should not be the same as "
+            "the rank of the current process."
+        )
+
+
 def _as_iterable(obj) -> collections.abc.Iterable:
     return obj if isinstance(obj, list) else (obj,)

@@ -2217,7 +2249,11 @@ def get_world_size(group: Optional[ProcessGroup] = None) -> int:


 def isend(
-    tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0
+    tensor: torch.Tensor,
+    dst: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    tag: int = 0,
+    group_dst: Optional[int] = None,
 ) -> Optional[Work]:
     """
     Send a tensor asynchronously.
@@ -2229,18 +2265,23 @@ def isend(
     .. warning::
         ``tag`` is not supported with the NCCL backend.

+    Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self.
+
     Args:
         tensor (Tensor): Tensor to send.
         dst (int): Destination rank on global process group (regardless of ``group`` argument)
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         tag (int, optional): Tag to match send with remote recv
+        group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``

     Returns:
         A distributed request object.
         None, if not part of the group

     """
+    group = _group_or_default_group(group)
+    group_dst = _canonicalize_group_rank(group, dst, group_dst)
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
         _warn_not_in_group("isend")
@@ -2249,34 +2290,32 @@ def isend(
     if tensor.is_complex():
         tensor = torch.view_as_real(tensor)

-    if group is None or group is GroupMember.WORLD:
-        pg = _get_default_group()
-    else:
-        pg = group
-        dst = get_group_rank(pg, dst)
-
-    return pg.send([tensor], dst, tag)
+    return group.send([tensor], group_dst, tag)


 def irecv(
     tensor: torch.Tensor,
     src: Optional[int] = None,
     group: Optional[ProcessGroup] = None,
     tag: int = 0,
+    group_src: Optional[int] = None,
 ) -> Optional[Work]:
     """
     Receives a tensor asynchronously.

     .. warning::
         ``tag`` is not supported with the NCCL backend.

+    Unlike recv, which is blocking, irecv allows src == dst rank, i.e. recv from self.
+
     Args:
         tensor (Tensor): Tensor to fill with received data.
         src (int, optional): Source rank on global process group (regardless of ``group`` argument).
             Will receive from any process if unspecified.
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         tag (int, optional): Tag to match recv with remote send
+        group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``.

     Returns:
         A distributed request object.
@@ -2291,24 +2330,21 @@ def irecv(
     if tensor.is_complex():
         tensor = torch.view_as_real(tensor)

-    if group is None or group is GroupMember.WORLD:
-        pg = _get_default_group()
-    else:
-        pg = group
-
-    if src is None:
-        return pg.recv_anysource([tensor], tag)
+    group = _group_or_default_group(group)
+    if src is None and group_src is None:
+        return group.recv_anysource([tensor], tag)
     else:
-        if pg is GroupMember.WORLD:
-            return pg.recv([tensor], src, tag)
-        else:
-            group_src_rank = get_group_rank(pg, src)
-            return pg.recv([tensor], group_src_rank, tag)
+        group_src = _canonicalize_group_rank(group, src, group_src)
+        return group.recv([tensor], group_src, tag)


 @_exception_logger
 def send(
-    tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: int = 0
+    tensor: torch.Tensor,
+    dst: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    tag: int = 0,
+    group_dst: Optional[int] = None,
 ) -> None:
     """
     Send a tensor synchronously.
@@ -2323,14 +2359,12 @@ def send(
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         tag (int, optional): Tag to match send with remote recv
+        group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``.

     """
-    if get_rank() == dst:
-        raise ValueError(
-            "Invalid destination rank: destination rank should not be the same as "
-            "the rank of the current process."
-        )
-
+    group = _group_or_default_group(group)
+    group_dst = _canonicalize_group_rank(group, dst, group_dst)
+    _check_not_self_rank(group, group_dst, "destination")
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
         _warn_not_in_group("send")
@@ -2339,12 +2373,7 @@ def send(
     if tensor.is_complex():
         tensor = torch.view_as_real(tensor)

-    if group is None or group is GroupMember.WORLD:
-        default_pg = _get_default_group()
-        default_pg.send([tensor], dst, tag).wait()
-    else:
-        group_dst_rank = get_group_rank(group, dst)
-        group.send([tensor], group_dst_rank, tag).wait()
+    group.send([tensor], group_dst, tag).wait()


 @_exception_logger
@@ -2353,6 +2382,7 @@ def recv(
     src: Optional[int] = None,
     group: Optional[ProcessGroup] = None,
     tag: int = 0,
+    group_src: Optional[int] = None,
 ) -> int:
     """
     Receives a tensor synchronously.
@@ -2367,7 +2397,7 @@ def recv(
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         tag (int, optional): Tag to match recv with remote send
-
+        group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``.
     Returns:
         Sender rank
         -1, if not part of the group
@@ -2381,23 +2411,18 @@ def recv(
     if tensor.is_complex():
         tensor = torch.view_as_real(tensor)

-    pg = group or _get_default_group()
+    group = _group_or_default_group(group)

-    if src is None:
-        work = pg.recv_anysource([tensor], tag)
+    if src is None and group_src is None:
+        work = group.recv_anysource([tensor], tag)
         work.wait()
         src_rank = work._source_rank()
-        if group is None or group is GroupMember.WORLD:
-            return src_rank
-        else:
-            return get_global_rank(pg, src_rank)
+        return get_global_rank(group, src_rank)
     else:
-        if group is None or group is GroupMember.WORLD:
-            pg.recv([tensor], src, tag).wait()
-        else:
-            group_src_rank = get_group_rank(pg, src)
-            pg.recv([tensor], group_src_rank, tag).wait()
-        return src
+        group_src = _canonicalize_group_rank(group, src, group_src)
+        _check_not_self_rank(group, group_src, "source")
+        group.recv([tensor], group_src, tag).wait()
+        return get_global_rank(group, group_src)


 class _IllegalWork(Work):
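
As a companion illustration (again not part of the diff), the global-rank/group-rank mapping that _canonicalize_group_rank relies on can be checked with the public helpers get_group_rank and get_global_rank; a 4-rank job with the default group already initialized is assumed:

import torch.distributed as dist

# Every rank must call new_group, even ranks that are not members of the subgroup.
subgroup = dist.new_group(ranks=[2, 3])

# Global rank 2 is group rank 0 inside 'subgroup', and vice versa.
assert dist.get_group_rank(subgroup, global_rank=2) == 0
assert dist.get_global_rank(subgroup, group_rank=0) == 2

# send/recv with 'dst'/'src' go through this translation internally, while
# 'group_dst'/'group_src' are used as-is; passing both raises ValueError,
# as exercised in test_c10d_common.py above.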
