@@ -3747,7 +3747,8 @@ def _init_two_pg2_subgroups(self, world_size: int = 4):
    @requires_nccl()
    @skip_if_lt_x_gpu(4)
-    def test_gather_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_gather_subgroup(self, group_rank):
        world_size = 4
        if self.rank >= world_size:
            # just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
@@ -3758,28 +3759,48 @@ def test_gather_subgroup(self):
        input = torch.ones((10,), device=device) * self.rank
        if self.rank == 0 or self.rank == 2:
            gather_list = [torch.empty_like(input) for _ in range(subgroup.size())]
-            torch.distributed.gather(
-                input,
-                gather_list=gather_list,
-                dst=self.rank,
-                group=subgroup,
-                async_op=False,
-            )
+            if group_rank:
+                # global_dst=0 group_dst=0 my_global_rank=2 gather_list is not None=True
+                torch.distributed.gather(
+                    input,
+                    gather_list=gather_list,
+                    group_dst=0,
+                    group=subgroup,
+                    async_op=False,
+                )
+            else:
+                torch.distributed.gather(
+                    input,
+                    gather_list=gather_list,
+                    dst=self.rank,
+                    group=subgroup,
+                    async_op=False,
+                )
            for src in range(len(gather_list)):
                expected = (torch.ones_like(input) * self.rank) + src
                self.assertEqual(gather_list[src], expected)
        else:
-            torch.distributed.gather(
-                input,
-                gather_list=None,
-                dst=self.rank - 1,
-                group=subgroup,
-                async_op=False,
-            )
+            if group_rank:
+                torch.distributed.gather(
+                    input,
+                    gather_list=None,
+                    group_dst=0,
+                    group=subgroup,
+                    async_op=False,
+                )
+            else:
+                torch.distributed.gather(
+                    input,
+                    gather_list=None,
+                    dst=self.rank - 1,
+                    group=subgroup,
+                    async_op=False,
+                )

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
-    def test_gather_object_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_gather_object_subgroup(self, group_rank):
        world_size = 4
        if self.rank >= world_size:
            # just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
@@ -3797,15 +3818,25 @@ def test_gather_object_subgroup(self):
            # another weird thing- what's the point of making me specify some empty objects in my list?
            # empty list should be valid imo. (but it throws an error)
            gather_list = [{}, {}]
-            torch.distributed.gather_object(
-                input, object_gather_list=gather_list, dst=self.rank, group=subgroup
-            )
+            if group_rank:
+                torch.distributed.gather_object(
+                    input, object_gather_list=gather_list, group_dst=0, group=subgroup
+                )
+            else:
+                torch.distributed.gather_object(
+                    input, object_gather_list=gather_list, dst=self.rank, group=subgroup
+                )
            for src in range(len(gather_list)):
                self.assertEqual(gather_list[src]["rank"], self.rank + src)
        else:
-            torch.distributed.gather_object(
-                input, object_gather_list=None, dst=self.rank - 1, group=subgroup
-            )
+            if group_rank:
+                torch.distributed.gather_object(
+                    input, object_gather_list=None, group_dst=0, group=subgroup
+                )
+            else:
+                torch.distributed.gather_object(
+                    input, object_gather_list=None, dst=self.rank - 1, group=subgroup
+                )

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
@@ -3931,7 +3962,8 @@ def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):
    @requires_nccl()
    @skip_if_lt_x_gpu(4)
-    def test_scatter_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_scatter_subgroup(self, group_rank):
        world_size = 4
        if self.rank >= world_size:
            return
@@ -3940,18 +3972,27 @@ def test_scatter_subgroup(self):
        x = torch.empty((10,), device=device)
        expected = torch.ones((10,), device=device) * self.rank
        if self.rank == 0 or self.rank == 2:
-            c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup)
+            if group_rank:
+                c10d.scatter(x, scatter_list=None, group_src=1, group=subgroup)
+            else:
+                c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup)
        else:
            scatter_list = [
                torch.ones((10,), device=device) * (self.rank - 1),
                torch.ones((10,), device=device) * self.rank,
            ]
-            c10d.scatter(x, scatter_list=scatter_list, src=self.rank, group=subgroup)
+            if group_rank:
+                c10d.scatter(x, scatter_list=scatter_list, group_src=1, group=subgroup)
+            else:
+                c10d.scatter(
+                    x, scatter_list=scatter_list, src=self.rank, group=subgroup
+                )
        self.assertEqual(x, expected)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
-    def test_scatter_object_list_subgroup(self):
+    @parametrize("group_rank", [True, False])
+    def test_scatter_object_list_subgroup(self, group_rank):
        world_size = 4
        if self.rank >= world_size:
            return
@@ -3960,24 +4001,40 @@ def test_scatter_object_list_subgroup(self):
        scatter_object_output_list = [None]
        expected = [{"rank": self.rank}]
        if self.rank == 0 or self.rank == 2:
-            c10d.scatter_object_list(
-                scatter_object_output_list=scatter_object_output_list,
-                scatter_object_input_list=None,
-                src=self.rank + 1,
-                group=subgroup,
-            )
+            if group_rank:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=None,
+                    group_src=1,
+                    group=subgroup,
+                )
+            else:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=None,
+                    src=self.rank + 1,
+                    group=subgroup,
+                )

        else:
            scatter_object_input_list = [
                {"rank": self.rank - 1},
                {"rank": self.rank},
            ]
-            c10d.scatter_object_list(
-                scatter_object_output_list=scatter_object_output_list,
-                scatter_object_input_list=scatter_object_input_list,
-                src=self.rank,
-                group=subgroup,
-            )
+            if group_rank:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=scatter_object_input_list,
+                    group_src=1,
+                    group=subgroup,
+                )
+            else:
+                c10d.scatter_object_list(
+                    scatter_object_output_list=scatter_object_output_list,
+                    scatter_object_input_list=scatter_object_input_list,
+                    src=self.rank,
+                    group=subgroup,
+                )
        self.assertEqual(scatter_object_output_list, expected)
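Context note (not part of the diff): the parametrized branches above exercise the group-scoped group_dst/group_src arguments against the legacy global dst/src arguments. A minimal sketch of the rank translation the tests rely on, assuming a 4-rank job split into 2-rank subgroups such as [0, 2] and [1, 3] as in _init_two_pg2_subgroups; the helper name to_global_dst is illustrative only, while get_global_rank is an existing torch.distributed helper.

import torch.distributed as dist

def to_global_dst(subgroup: dist.ProcessGroup, group_dst: int) -> int:
    # Translate a group-local rank (what group_dst=/group_src= expect)
    # into the global rank (what the legacy dst=/src= arguments expect).
    return dist.get_global_rank(subgroup, group_dst)

# Example mapping for the subgroups used in these tests:
#   subgroup [0, 2]: group_dst=0 -> global dst 0 (matches dst=self.rank on ranks 0/2)
#   subgroup [1, 3]: group_src=1 -> global src 3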