@@ -1112,6 +1112,38 @@ def _check_tensor_list(param, param_name) -> None:
1112
1112
)
1113
1113
1114
1114
1115
+ def _group_or_default_group (group : Optional [ProcessGroup ] = None ) -> ProcessGroup :
1116
+ if group is None or group is GroupMember .WORLD :
1117
+ group = _get_default_group ()
1118
+ return group
1119
+
1120
+
1121
+ def _canonicalize_group_rank (
1122
+ group : ProcessGroup ,
1123
+ global_rank : Optional [int ] = None ,
1124
+ group_rank : Optional [int ] = None ,
1125
+ ) -> int :
1126
+ """
1127
+ Helper method to take _either_ a global rank or a group rank and produce a group rank.
1128
+ """
1129
+ if group_rank is not None :
1130
+ if global_rank is not None :
1131
+ raise ValueError ("Can't specify both group_rank and global_rank" )
1132
+ else :
1133
+ if global_rank is None :
1134
+ raise ValueError ("Must specify global_rank or group_rank" )
1135
+ group_rank = get_group_rank (group , global_rank )
1136
+ return group_rank
1137
+
1138
+
1139
+ def _check_not_self_rank (group : ProcessGroup , rank : int , rank_type : str ):
1140
+ if group .rank () == rank :
1141
+ raise ValueError (
1142
+ f"Invalid { rank_type } rank: { rank_type } rank should not be the same as "
1143
+ "the rank of the current process."
1144
+ )
1145
+
1146
+
1115
1147
def _as_iterable (obj ) -> collections .abc .Iterable :
1116
1148
return obj if isinstance (obj , list ) else (obj ,)
1117
1149
@@ -2217,7 +2249,11 @@ def get_world_size(group: Optional[ProcessGroup] = None) -> int:
2217
2249
2218
2250
2219
2251
def isend (
2220
- tensor : torch .Tensor , dst : int , group : Optional [ProcessGroup ] = None , tag : int = 0
2252
+ tensor : torch .Tensor ,
2253
+ dst : Optional [int ] = None ,
2254
+ group : Optional [ProcessGroup ] = None ,
2255
+ tag : int = 0 ,
2256
+ group_dst : Optional [int ] = None ,
2221
2257
) -> Optional [Work ]:
2222
2258
"""
2223
2259
Send a tensor asynchronously.
@@ -2229,18 +2265,23 @@ def isend(
2229
2265
.. warning::
2230
2266
``tag`` is not supported with the NCCL backend.
2231
2267
2268
+ Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self.
2269
+
2232
2270
Args:
2233
2271
tensor (Tensor): Tensor to send.
2234
2272
dst (int): Destination rank on global process group (regardless of ``group`` argument)
2235
2273
group (ProcessGroup, optional): The process group to work on. If None,
2236
2274
the default process group will be used.
2237
2275
tag (int, optional): Tag to match send with remote recv
2276
+ group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``
2238
2277
2239
2278
Returns:
2240
2279
A distributed request object.
2241
2280
None, if not part of the group
2242
2281
2243
2282
"""
2283
+ group = _group_or_default_group (group )
2284
+ group_dst = _canonicalize_group_rank (group , dst , group_dst )
2244
2285
_check_single_tensor (tensor , "tensor" )
2245
2286
if _rank_not_in_group (group ):
2246
2287
_warn_not_in_group ("isend" )
@@ -2249,34 +2290,32 @@ def isend(
2249
2290
if tensor .is_complex ():
2250
2291
tensor = torch .view_as_real (tensor )
2251
2292
2252
- if group is None or group is GroupMember .WORLD :
2253
- pg = _get_default_group ()
2254
- else :
2255
- pg = group
2256
- dst = get_group_rank (pg , dst )
2257
-
2258
- return pg .send ([tensor ], dst , tag )
2293
+ return group .send ([tensor ], group_dst , tag )
2259
2294
2260
2295
2261
2296
def irecv (
2262
2297
tensor : torch .Tensor ,
2263
2298
src : Optional [int ] = None ,
2264
2299
group : Optional [ProcessGroup ] = None ,
2265
2300
tag : int = 0 ,
2301
+ group_src : Optional [int ] = None ,
2266
2302
) -> Optional [Work ]:
2267
2303
"""
2268
2304
Receives a tensor asynchronously.
2269
2305
2270
2306
.. warning::
2271
2307
``tag`` is not supported with the NCCL backend.
2272
2308
2309
+ Unlike recv, which is blocking, irecv allows src == dst rank, i.e. recv from self.
2310
+
2273
2311
Args:
2274
2312
tensor (Tensor): Tensor to fill with received data.
2275
2313
src (int, optional): Source rank on global process group (regardless of ``group`` argument).
2276
2314
Will receive from any process if unspecified.
2277
2315
group (ProcessGroup, optional): The process group to work on. If None,
2278
2316
the default process group will be used.
2279
2317
tag (int, optional): Tag to match recv with remote send
2318
+ group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``.
2280
2319
2281
2320
Returns:
2282
2321
A distributed request object.
@@ -2291,24 +2330,21 @@ def irecv(
2291
2330
if tensor .is_complex ():
2292
2331
tensor = torch .view_as_real (tensor )
2293
2332
2294
- if group is None or group is GroupMember .WORLD :
2295
- pg = _get_default_group ()
2296
- else :
2297
- pg = group
2298
-
2299
- if src is None :
2300
- return pg .recv_anysource ([tensor ], tag )
2333
+ group = _group_or_default_group (group )
2334
+ if src is None and group_src is None :
2335
+ return group .recv_anysource ([tensor ], tag )
2301
2336
else :
2302
- if pg is GroupMember .WORLD :
2303
- return pg .recv ([tensor ], src , tag )
2304
- else :
2305
- group_src_rank = get_group_rank (pg , src )
2306
- return pg .recv ([tensor ], group_src_rank , tag )
2337
+ group_src = _canonicalize_group_rank (group , src , group_src )
2338
+ return group .recv ([tensor ], group_src , tag )
2307
2339
2308
2340
2309
2341
@_exception_logger
2310
2342
def send (
2311
- tensor : torch .Tensor , dst : int , group : Optional [ProcessGroup ] = None , tag : int = 0
2343
+ tensor : torch .Tensor ,
2344
+ dst : Optional [int ] = None ,
2345
+ group : Optional [ProcessGroup ] = None ,
2346
+ tag : int = 0 ,
2347
+ group_dst : Optional [int ] = None ,
2312
2348
) -> None :
2313
2349
"""
2314
2350
Send a tensor synchronously.
@@ -2323,14 +2359,12 @@ def send(
2323
2359
group (ProcessGroup, optional): The process group to work on. If None,
2324
2360
the default process group will be used.
2325
2361
tag (int, optional): Tag to match send with remote recv
2362
+ group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``.
2326
2363
2327
2364
"""
2328
- if get_rank () == dst :
2329
- raise ValueError (
2330
- "Invalid destination rank: destination rank should not be the same as "
2331
- "the rank of the current process."
2332
- )
2333
-
2365
+ group = _group_or_default_group (group )
2366
+ group_dst = _canonicalize_group_rank (group , dst , group_dst )
2367
+ _check_not_self_rank (group , group_dst , "destination" )
2334
2368
_check_single_tensor (tensor , "tensor" )
2335
2369
if _rank_not_in_group (group ):
2336
2370
_warn_not_in_group ("send" )
@@ -2339,12 +2373,7 @@ def send(
2339
2373
if tensor .is_complex ():
2340
2374
tensor = torch .view_as_real (tensor )
2341
2375
2342
- if group is None or group is GroupMember .WORLD :
2343
- default_pg = _get_default_group ()
2344
- default_pg .send ([tensor ], dst , tag ).wait ()
2345
- else :
2346
- group_dst_rank = get_group_rank (group , dst )
2347
- group .send ([tensor ], group_dst_rank , tag ).wait ()
2376
+ group .send ([tensor ], group_dst , tag ).wait ()
2348
2377
2349
2378
2350
2379
@_exception_logger
@@ -2353,6 +2382,7 @@ def recv(
2353
2382
src : Optional [int ] = None ,
2354
2383
group : Optional [ProcessGroup ] = None ,
2355
2384
tag : int = 0 ,
2385
+ group_src : Optional [int ] = None ,
2356
2386
) -> int :
2357
2387
"""
2358
2388
Receives a tensor synchronously.
@@ -2367,7 +2397,7 @@ def recv(
2367
2397
group (ProcessGroup, optional): The process group to work on. If None,
2368
2398
the default process group will be used.
2369
2399
tag (int, optional): Tag to match recv with remote send
2370
-
2400
+ group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``.
2371
2401
Returns:
2372
2402
Sender rank
2373
2403
-1, if not part of the group
@@ -2381,23 +2411,18 @@ def recv(
2381
2411
if tensor .is_complex ():
2382
2412
tensor = torch .view_as_real (tensor )
2383
2413
2384
- pg = group or _get_default_group ( )
2414
+ group = _group_or_default_group ( group )
2385
2415
2386
- if src is None :
2387
- work = pg .recv_anysource ([tensor ], tag )
2416
+ if src is None and group_src is None :
2417
+ work = group .recv_anysource ([tensor ], tag )
2388
2418
work .wait ()
2389
2419
src_rank = work ._source_rank ()
2390
- if group is None or group is GroupMember .WORLD :
2391
- return src_rank
2392
- else :
2393
- return get_global_rank (pg , src_rank )
2420
+ return get_global_rank (group , src_rank )
2394
2421
else :
2395
- if group is None or group is GroupMember .WORLD :
2396
- pg .recv ([tensor ], src , tag ).wait ()
2397
- else :
2398
- group_src_rank = get_group_rank (pg , src )
2399
- pg .recv ([tensor ], group_src_rank , tag ).wait ()
2400
- return src
2422
+ group_src = _canonicalize_group_rank (group , src , group_src )
2423
+ _check_not_self_rank (group , group_src , "source" )
2424
+ group .recv ([tensor ], group_src , tag ).wait ()
2425
+ return get_global_rank (group , group_src )
2401
2426
2402
2427
2403
2428
class _IllegalWork (Work ):
0 commit comments