
Commit b04c7b8

isururanawaka authored and facebook-github-bot committed
Add random_seed for regular model parallel tests to ensure actual randomness in generating embeddings/inputs etc... (#3158)
Summary:
Pull Request resolved: #3158

Adds random_seed as an optional parameter to the gen_model_and_input method so that it can be used by other testing methods.

Reviewed By: aporialiao

Differential Revision: D77742701

fbshipit-source-id: 752c7a9fd84436a5de862413f1851cf09a25d38e
1 parent 140e979 commit b04c7b8
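
Each helper this diff touches applies the same idea: seed explicitly when the caller asks, otherwise preserve the previous behavior (ModelInput.generate stays unseeded, generate_variable_batch_input keeps its hard-coded 100, gen_model_and_input keeps 0). A minimal sketch of that pattern follows; the _maybe_seed helper is hypothetical and not part of this commit:

import random
from typing import Optional

import torch


def _maybe_seed(random_seed: Optional[int], default: Optional[int]) -> None:
    # Hypothetical helper mirroring the commit's seeding pattern:
    # prefer the caller's seed, fall back to the legacy default, and
    # do nothing when neither is given (as in ModelInput.generate,
    # which only seeds on request). Note the actual gen_model_and_input
    # seeds only torch; generate_variable_batch_input seeds both RNGs.
    seed = random_seed if random_seed is not None else default
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)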

File tree

3 files changed: +27 -3 lines changed

torchrec/distributed/test_utils/test_model.py
torchrec/distributed/test_utils/test_sharding.py
torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py

torchrec/distributed/test_utils/test_model.py

Lines changed: 10 additions & 2 deletions
@@ -100,11 +100,14 @@ def generate(
         indices_dtype: torch.dtype = torch.int64,
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
+        random_seed: Optional[int] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         """
         Returns a global (single-rank training) batch
         and a list of local (multi-rank training) batches of world_size.
         """
+        if random_seed is not None:
+            torch.manual_seed(random_seed)
         batch_size_by_rank = [batch_size] * world_size
         if variable_batch_size:
             batch_size_by_rank = [
@@ -751,9 +754,14 @@ def generate_variable_batch_input(
         indices_dtype: torch.dtype = torch.int64,
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
+        random_seed: Optional[int] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]:
-        torch.manual_seed(100)
-        random.seed(100)
+        if random_seed is not None:
+            torch.manual_seed(random_seed)
+            random.seed(random_seed)
+        else:
+            torch.manual_seed(100)
+            random.seed(100)
         dedup_factor = 2
 
         global_kjt, local_kjts = ModelInput._generate_variable_batch_features(
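
With the new parameter, a test can pin the generated batches and reproduce them across runs. A small sketch, assuming ModelInput.generate's leading arguments (batch_size, world_size, num_float_features, tables, weighted_tables) as defined in test_model.py; the table config values are made up for illustration:

import torch
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        num_embeddings=64,
        embedding_dim=8,
        name="table_0",
        feature_names=["feature_0"],
    )
]


def make_inputs():
    # With a fixed random_seed, repeated calls should produce
    # identical global and per-rank batches.
    return ModelInput.generate(
        batch_size=4,
        world_size=2,
        num_float_features=2,
        tables=tables,
        weighted_tables=[],
        random_seed=1234,
    )


global_a, _ = make_inputs()
global_b, _ = make_inputs()
assert torch.equal(global_a.float_features, global_b.float_features)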

torchrec/distributed/test_utils/test_sharding.py

Lines changed: 16 additions & 1 deletion
@@ -163,6 +163,7 @@ def __call__(
         indices_dtype: torch.dtype = torch.int64,
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
+        random_seed: Optional[int] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]: ...
@@ -180,6 +181,7 @@ def __call__(
         indices_dtype: torch.dtype = torch.int64,
         offsets_dtype: torch.dtype = torch.int64,
         lengths_dtype: torch.dtype = torch.int64,
+        random_seed: Optional[int] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]: ...
@@ -208,8 +210,12 @@ def gen_model_and_input(
     global_constant_batch: bool = False,
     num_inputs: int = 1,
     input_type: str = "kjt",  # "kjt" or "td"
+    random_seed: Optional[int] = None,
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
-    torch.manual_seed(0)
+    if random_seed is not None:
+        torch.manual_seed(random_seed)
+    else:
+        torch.manual_seed(0)
     if dedup_feature_names:
         model = model_class(
             tables=cast(
@@ -252,6 +258,7 @@ def gen_model_and_input(
                     indices_dtype=indices_dtype,
                     offsets_dtype=offsets_dtype,
                     lengths_dtype=lengths_dtype,
+                    random_seed=random_seed,
                 )
             )
     elif generate == ModelInput.generate:
@@ -270,6 +277,7 @@ def gen_model_and_input(
                     indices_dtype=indices_dtype,
                     offsets_dtype=offsets_dtype,
                     lengths_dtype=lengths_dtype,
+                    random_seed=random_seed,
                 )
             )
     else:
@@ -287,6 +295,7 @@ def gen_model_and_input(
                     indices_dtype=indices_dtype,
                     offsets_dtype=offsets_dtype,
                     lengths_dtype=lengths_dtype,
+                    random_seed=random_seed,
                 )
             )
     return (model, inputs)
@@ -742,6 +751,7 @@ def sharding_single_rank_test_single_process(
     indices_dtype: torch.dtype = torch.int64,
     offsets_dtype: torch.dtype = torch.int64,
     lengths_dtype: torch.dtype = torch.int64,
+    random_seed: Optional[int] = None,
 ) -> None:
     batch_size = random.randint(0, batch_size) if allow_zero_batch_size else batch_size
     # Generate model & inputs.
@@ -770,7 +780,9 @@ def sharding_single_rank_test_single_process(
         indices_dtype=indices_dtype,
         offsets_dtype=offsets_dtype,
         lengths_dtype=lengths_dtype,
+        random_seed=random_seed,
     )
+
     global_model = global_model.to(device)
     global_input = inputs[0][0].to(device)
     local_input = inputs[0][1][rank].to(device)
@@ -818,6 +830,7 @@ def sharding_single_rank_test_single_process(
         constraints=constraints,
     )
     plan: ShardingPlan = planner.collective_plan(local_model, sharders, pg)
+
     """
     Simulating multiple nodes on a single node. However, metadata information and
     tensor placement must still be consistent. Here we overwrite this to do so.
@@ -994,6 +1007,7 @@ def sharding_single_rank_test(
     indices_dtype: torch.dtype = torch.int64,
     offsets_dtype: torch.dtype = torch.int64,
     lengths_dtype: torch.dtype = torch.int64,
+    random_seed: Optional[int] = None,
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         assert ctx.pg is not None
@@ -1027,6 +1041,7 @@ def sharding_single_rank_test(
             indices_dtype=indices_dtype,
             offsets_dtype=offsets_dtype,
             lengths_dtype=lengths_dtype,
+            random_seed=random_seed,
         )
 
 
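At the sharding-test level the seed simply threads through: sharding_single_rank_test passes it to gen_model_and_input, which seeds model construction and forwards it to whichever input generator is selected. A hedged call-site sketch; the argument values are illustrative, and TestSparseNN plus the tables list stand in for whatever a given test actually uses:

from torchrec.distributed.test_utils.test_model import ModelInput, TestSparseNN
from torchrec.distributed.test_utils.test_sharding import gen_model_and_input

model, model_inputs = gen_model_and_input(
    model_class=TestSparseNN,      # assumed test model class
    tables=tables,                 # e.g. the EmbeddingBagConfig list above
    weighted_tables=[],
    embedding_groups={},
    world_size=2,
    generate=ModelInput.generate,
    random_seed=1234,              # new: also forwarded to the generator
)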

torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py

Lines changed: 1 addition & 0 deletions
@@ -609,6 +609,7 @@ def test_ssd_mixed_kernels_with_vbe(
             },
             constraints=constraints,
             variable_batch_per_feature=True,
+            random_seed=100,
         )
 
     @unittest.skipIf(
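
Note: random_seed=100 at this call site matches the seed that generate_variable_batch_input previously hard-coded, so the VBE test keeps deterministic inputs; the choice of seed is now explicit at the call site rather than buried in the helper.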
