@@ -163,6 +163,7 @@ def __call__(
         indices_dtype: torch.dtype = torch.int64,
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
+        random_seed: Optional[int] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]: ...


@@ -180,6 +181,7 @@ def __call__(
         indices_dtype: torch.dtype = torch.int64,
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
+        random_seed: Optional[int] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]: ...


@@ -208,8 +210,12 @@ def gen_model_and_input(
     global_constant_batch: bool = False,
     num_inputs: int = 1,
     input_type: str = "kjt",  # "kjt" or "td"
+    random_seed: Optional[int] = None,
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
-    torch.manual_seed(0)
+    if random_seed is not None:
+        torch.manual_seed(random_seed)
+    else:
+        torch.manual_seed(0)
     if dedup_feature_names:
         model = model_class(
             tables=cast(
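
The change above keeps the old default (torch.manual_seed(0)) when no seed is
passed, so existing tests remain byte-for-byte reproducible while callers can
now opt into a specific seed. As a minimal standalone sketch of the same
fallback pattern (the helper name _seed_rng is hypothetical, not part of this
diff):

from typing import Optional

import torch


def _seed_rng(random_seed: Optional[int] = None) -> None:
    # Mirror the patched gen_model_and_input: fall back to the historical
    # default seed of 0 when the caller does not supply one.
    torch.manual_seed(random_seed if random_seed is not None else 0)


_seed_rng()      # old behavior: torch.manual_seed(0)
_seed_rng(1234)  # opt-in: caller-chosen deterministic seed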
@@ -252,6 +258,7 @@ def gen_model_and_input(
                 indices_dtype=indices_dtype,
                 offsets_dtype=offsets_dtype,
                 lengths_dtype=lengths_dtype,
+                random_seed=random_seed,
             )
         )
     elif generate == ModelInput.generate:
@@ -270,6 +277,7 @@ def gen_model_and_input(
                 indices_dtype=indices_dtype,
                 offsets_dtype=offsets_dtype,
                 lengths_dtype=lengths_dtype,
+                random_seed=random_seed,
             )
         )
     else:
@@ -287,6 +295,7 @@ def gen_model_and_input(
                 indices_dtype=indices_dtype,
                 offsets_dtype=offsets_dtype,
                 lengths_dtype=lengths_dtype,
+                random_seed=random_seed,
             )
         )
     return (model, inputs)
@@ -742,6 +751,7 @@ def sharding_single_rank_test_single_process(
     indices_dtype: torch.dtype = torch.int64,
     offsets_dtype: torch.dtype = torch.int64,
     lengths_dtype: torch.dtype = torch.int64,
+    random_seed: Optional[int] = None,
 ) -> None:
     batch_size = random.randint(0, batch_size) if allow_zero_batch_size else batch_size
     # Generate model & inputs.
@@ -770,7 +780,9 @@ def sharding_single_rank_test_single_process(
         indices_dtype=indices_dtype,
         offsets_dtype=offsets_dtype,
         lengths_dtype=lengths_dtype,
+        random_seed=random_seed,
     )
+
     global_model = global_model.to(device)
     global_input = inputs[0][0].to(device)
     local_input = inputs[0][1][rank].to(device)
@@ -818,6 +830,7 @@ def sharding_single_rank_test_single_process(
         constraints=constraints,
     )
     plan: ShardingPlan = planner.collective_plan(local_model, sharders, pg)
+
     """
     Simulating multiple nodes on a single node. However, metadata information and
     tensor placement must still be consistent. Here we overwrite this to do so.
@@ -994,6 +1007,7 @@ def sharding_single_rank_test(
     indices_dtype: torch.dtype = torch.int64,
     offsets_dtype: torch.dtype = torch.int64,
     lengths_dtype: torch.dtype = torch.int64,
+    random_seed: Optional[int] = None,
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         assert ctx.pg is not None
@@ -1027,6 +1041,7 @@ def sharding_single_rank_test(
             indices_dtype=indices_dtype,
             offsets_dtype=offsets_dtype,
             lengths_dtype=lengths_dtype,
+            random_seed=random_seed,
         )


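Taken together, the diff threads random_seed from the multiprocess sharding
tests down to input generation, so two runs with the same seed should produce
identical models and inputs. A rough self-contained sketch of the property this
enables; generate_inputs is a hypothetical stand-in for ModelInput.generate,
not the real torchrec API:

from typing import Optional, Tuple

import torch


def generate_inputs(
    batch_size: int = 4,
    num_indices: int = 16,
    random_seed: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Seed exactly as the patched code does, then draw random sparse data.
    torch.manual_seed(random_seed if random_seed is not None else 0)
    lengths = torch.randint(0, 4, (batch_size,), dtype=torch.int64)
    indices = torch.randint(0, 100, (num_indices,), dtype=torch.int64)
    return indices, lengths


a = generate_inputs(random_seed=42)
b = generate_inputs(random_seed=42)
assert torch.equal(a[0], b[0]) and torch.equal(a[1], b[1])  # same seed, same data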