
Commit 05cf211

wconstab authored and pobin6 committed
[C10D] Support group_dst in scatter/gather (+object) ops (pytorch#140827)
Also add missing mypy typing and a few asserts to make mypy happy.

Partially addresses RFC 0042 (pytorch/rfcs#71). See more details/motivation in pytorch#140460.

Note: the object collective versions canonicalize to global rank instead of group rank, simply because this left more of the original code intact and required fewer conversions overall.

Pull Request resolved: pytorch#140827
Approved by: https://github.com/kwen2501
1 parent afdd34b commit 05cf211
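For illustration, a minimal sketch of the new keyword (hypothetical helper name; assumes a build containing this change, an initialized default process group, and a 2-rank `subgroup` that includes the current process):

import torch
import torch.distributed as dist

def gather_with_group_rank(subgroup, device):
    # Sketch only: group_dst names the destination by its rank *inside*
    # `subgroup` (here 0), whereas the pre-existing dst= argument expects
    # the destination's global rank.
    x = torch.ones((10,), device=device) * dist.get_rank()
    if dist.get_group_rank(subgroup, dist.get_rank()) == 0:
        gather_list = [torch.empty_like(x) for _ in range(subgroup.size())]
        dist.gather(x, gather_list=gather_list, group_dst=0, group=subgroup)
    else:
        dist.gather(x, gather_list=None, group_dst=0, group=subgroup)
    # scatter / scatter_object_list gain the analogous group_src= keyword, e.g.
    #   dist.scatter(out, scatter_list=..., group_src=1, group=subgroup)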

3 files changed: +197 −91 lines changed


test/distributed/test_c10d_nccl.py

Lines changed: 95 additions & 38 deletions
@@ -3747,7 +3747,8 @@ def _init_two_pg2_subgroups(self, world_size: int = 4):
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_gather_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_gather_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             # just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
@@ -3758,28 +3759,48 @@ def test_gather_subgroup(self):
         input = torch.ones((10,), device=device) * self.rank
         if self.rank == 0 or self.rank == 2:
             gather_list = [torch.empty_like(input) for _ in range(subgroup.size())]
-            torch.distributed.gather(
-                input,
-                gather_list=gather_list,
-                dst=self.rank,
-                group=subgroup,
-                async_op=False,
-            )
+            if group_rank:
+                # global_dst=0 group_dst=0 my_global_rank=2 gather_list is not None=True
+                torch.distributed.gather(
+                    input,
+                    gather_list=gather_list,
+                    group_dst=0,
+                    group=subgroup,
+                    async_op=False,
+                )
+            else:
+                torch.distributed.gather(
+                    input,
+                    gather_list=gather_list,
+                    dst=self.rank,
+                    group=subgroup,
+                    async_op=False,
+                )
             for src in range(len(gather_list)):
                 expected = (torch.ones_like(input) * self.rank) + src
                 self.assertEqual(gather_list[src], expected)
         else:
-            torch.distributed.gather(
-                input,
-                gather_list=None,
-                dst=self.rank - 1,
-                group=subgroup,
-                async_op=False,
-            )
+            if group_rank:
+                torch.distributed.gather(
+                    input,
+                    gather_list=None,
+                    group_dst=0,
+                    group=subgroup,
+                    async_op=False,
+                )
+            else:
+                torch.distributed.gather(
+                    input,
+                    gather_list=None,
+                    dst=self.rank - 1,
+                    group=subgroup,
+                    async_op=False,
+                )
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_gather_object_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_gather_object_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             # just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
@@ -3797,15 +3818,25 @@ def test_gather_object_subgroup(self):
             # another weird thing- what's the point of making me specify some empty objects in my list?
             # empty list should be valid imo. (but it throws an error)
             gather_list = [{}, {}]
-            torch.distributed.gather_object(
-                input, object_gather_list=gather_list, dst=self.rank, group=subgroup
-            )
+            if group_rank:
+                torch.distributed.gather_object(
+                    input, object_gather_list=gather_list, group_dst=0, group=subgroup
+                )
+            else:
+                torch.distributed.gather_object(
+                    input, object_gather_list=gather_list, dst=self.rank, group=subgroup
+                )
             for src in range(len(gather_list)):
                 self.assertEqual(gather_list[src]["rank"], self.rank + src)
         else:
-            torch.distributed.gather_object(
-                input, object_gather_list=None, dst=self.rank - 1, group=subgroup
-            )
+            if group_rank:
+                torch.distributed.gather_object(
+                    input, object_gather_list=None, group_dst=0, group=subgroup
+                )
+            else:
+                torch.distributed.gather_object(
+                    input, object_gather_list=None, dst=self.rank - 1, group=subgroup
+                )
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
@@ -3931,7 +3962,8 @@ def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_scatter_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_scatter_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3940,18 +3972,27 @@ def test_scatter_subgroup(self):
         x = torch.empty((10,), device=device)
         expected = torch.ones((10,), device=device) * self.rank
         if self.rank == 0 or self.rank == 2:
-            c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup)
+            if group_rank:
+                c10d.scatter(x, scatter_list=None, group_src=1, group=subgroup)
+            else:
+                c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup)
         else:
             scatter_list = [
                 torch.ones((10,), device=device) * (self.rank - 1),
                 torch.ones((10,), device=device) * self.rank,
             ]
-            c10d.scatter(x, scatter_list=scatter_list, src=self.rank, group=subgroup)
+            if group_rank:
+                c10d.scatter(x, scatter_list=scatter_list, group_src=1, group=subgroup)
+            else:
+                c10d.scatter(
+                    x, scatter_list=scatter_list, src=self.rank, group=subgroup
+                )
         self.assertEqual(x, expected)
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
-    def test_scatter_object_list_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_scatter_object_list_subgroup(self, group_rank):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3960,24 +4001,40 @@ def test_scatter_object_list_subgroup(self):
         scatter_object_output_list = [None]
         expected = [{"rank": self.rank}]
         if self.rank == 0 or self.rank == 2:
-            c10d.scatter_object_list(
-                scatter_object_output_list=scatter_object_output_list,
-                scatter_object_input_list=None,
-                src=self.rank + 1,
-                group=subgroup,
-            )
+            if group_rank:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=None,
+                    group_src=1,
+                    group=subgroup,
+                )
+            else:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=None,
+                    src=self.rank + 1,
+                    group=subgroup,
+                )
 
         else:
             scatter_object_input_list = [
                 {"rank": self.rank - 1},
                 {"rank": self.rank},
             ]
-            c10d.scatter_object_list(
-                scatter_object_output_list=scatter_object_output_list,
-                scatter_object_input_list=scatter_object_input_list,
-                src=self.rank,
-                group=subgroup,
-            )
+            if group_rank:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=scatter_object_input_list,
+                    group_src=1,
+                    group=subgroup,
+                )
+            else:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=scatter_object_input_list,
+                    src=self.rank,
+                    group=subgroup,
+                )
         self.assertEqual(scatter_object_output_list, expected)

torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py

Lines changed: 9 additions & 1 deletion
@@ -132,6 +132,7 @@ def shard(
         local_shards = []
         local_tensor = None
        local_metadata = None
+
         tensors_to_scatter = cast(
             List[Optional[torch.Tensor]],
             [None] * dist.get_world_size(process_group),
@@ -192,9 +193,16 @@ def shard(
                 process_group, src_for_scatter
             )
 
+        tensors_to_scatter_: Optional[List[torch.Tensor]] = None
+        if current_rank == src_rank:
+            tensors_to_scatter_ = []
+            for t in tensors_to_scatter:
+                assert isinstance(t, torch.Tensor)
+                tensors_to_scatter_.append(t)
+
         dist.scatter(
             local_tensor,
-            scatter_list=tensors_to_scatter if current_rank == src_rank else None,
+            scatter_list=tensors_to_scatter_,
             src=src_for_scatter,
             group=process_group,
         )
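As background for the asserts above, a standalone sketch of the narrowing pattern (hypothetical helper name; assumes scatter's scatter_list parameter is annotated as Optional[List[torch.Tensor]], which a List[Optional[torch.Tensor]] value does not satisfy under mypy):

from typing import List, Optional
import torch

def narrow_scatter_list(
    maybe_tensors: List[Optional[torch.Tensor]], is_src: bool
) -> Optional[List[torch.Tensor]]:
    # Non-source ranks pass None; on the source rank the asserts prove to
    # mypy that every entry is a real tensor before the list is used.
    if not is_src:
        return None
    out: List[torch.Tensor] = []
    for t in maybe_tensors:
        assert isinstance(t, torch.Tensor)
        out.append(t)
    return out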
