@@ -885,11 +885,11 @@ def test_model_parallelism(self):
885885
886886 @require_torch_gpu
887887 def test_sharded_checkpoints (self ):
888+ torch .manual_seed (0 )
888889 config , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
889890 model = self .model_class (** config ).eval ()
890891 model = model .to (torch_device )
891892
892- torch .manual_seed (0 )
893893 base_output = model (** inputs_dict )
894894
895895 model_size = compute_module_sizes (model )["" ]
@@ -909,7 +909,8 @@ def test_sharded_checkpoints(self):
909909 new_model = new_model .to (torch_device )
910910
911911 torch .manual_seed (0 )
912- _ , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
912+ if "generator" in inputs_dict :
913+ _ , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
913914 new_output = new_model (** inputs_dict )
914915
915916 self .assertTrue (torch .allclose (base_output [0 ], new_output [0 ], atol = 1e-5 ))
@@ -942,7 +943,8 @@ def test_sharded_checkpoints_device_map(self):
942943 new_model = new_model .to (torch_device )
943944
944945 torch .manual_seed (0 )
945- _ , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
946+ if "generator" in inputs_dict :
947+ _ , inputs_dict = self .prepare_init_args_and_inputs_for_common ()
946948 new_output = new_model (** inputs_dict )
947949 self .assertTrue (torch .allclose (base_output [0 ], new_output [0 ], atol = 1e-5 ))
948950
0 commit comments