@@ -102,7 +102,7 @@ def test_train_horovod(run_module, single_machine_training_env):
102
102
@pytest .mark .skipif (sys .version_info .major != 3 ,
103
103
reason = "Skip this for python 2 because of dict key order mismatch" )
104
104
@patch ('tensorflow.train.ClusterSpec' )
105
- @patch ('tensorflow.train .Server' )
105
+ @patch ('tensorflow.distribute .Server' )
106
106
@patch ('sagemaker_containers.beta.framework.entry_point.run' )
107
107
@patch ('multiprocessing.Process' , lambda target : target ())
108
108
@patch ('time.sleep' , MagicMock ())
@@ -114,7 +114,7 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
114
114
'ps' : ['host1:2223' , 'host2:2223' ]})
115
115
116
116
tf_server .assert_called_with (
117
- cluster_spec (), job_name = 'ps' , task_index = 0 , config = tf .ConfigProto (device_count = {'GPU' : 0 })
117
+ cluster_spec (), job_name = 'ps' , task_index = 0 , config = tf .compat . v1 . ConfigProto (device_count = {'GPU' : 0 })
118
118
)
119
119
tf_server ().join .assert_called_with ()
120
120
@@ -132,7 +132,7 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
132
132
@pytest .mark .skipif (sys .version_info .major != 3 ,
133
133
reason = "Skip this for python 2 because of dict key order mismatch" )
134
134
@patch ('tensorflow.train.ClusterSpec' )
135
- @patch ('tensorflow.train .Server' )
135
+ @patch ('tensorflow.distribute .Server' )
136
136
@patch ('sagemaker_containers.beta.framework.entry_point.run' )
137
137
@patch ('multiprocessing.Process' , lambda target : target ())
138
138
@patch ('time.sleep' , MagicMock ())
@@ -146,7 +146,7 @@ def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_trai
146
146
'ps' : ['host1:2223' , 'host2:2223' ]})
147
147
148
148
tf_server .assert_called_with (
149
- cluster_spec (), job_name = 'ps' , task_index = 1 , config = tf .ConfigProto (device_count = {'GPU' : 0 })
149
+ cluster_spec (), job_name = 'ps' , task_index = 1 , config = tf .compat . v1 . ConfigProto (device_count = {'GPU' : 0 })
150
150
)
151
151
tf_server ().join .assert_called_with ()
152
152
0 commit comments