
Commit 40cf5b4

fix: tensorflow-2.0 library code changes (#247)

chuyang-deng authored and nadiaya committed

1 parent 57aef72 · commit 40cf5b4

File tree: 3 files changed, +7 −7 lines

  setup.py
  src/sagemaker_tensorflow_container/training.py
  test/unit/test_training.py

setup.py

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ def read_version():
                         'pandas', 'Pillow', 'h5py'],
     extras_require={
         'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist', 'mock',
-                 'sagemaker==1.19.1', 'tensorflow<2.0', 'docker-compose', 'botocore>=1.12.140'],
+                 'sagemaker==1.19.1', 'tensorflow', 'docker-compose', 'botocore>=1.12.140'],
         'benchmark': ['click']
     },
 )
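
Unpinning tensorflow in the test extra is what lets pip resolve a 2.x release, and that in turn forces the symbol renames in the two files below. As a minimal sketch (not code from this commit), tf.compat.v1 is the bridge that keeps a single code path valid on late TF 1.x as well as 2.x:

import tensorflow as tf

# Both of these resolve on either side of the upgrade: tf.version.VERSION
# and the tf.compat.v1 namespace exist in TF 1.14+ and in all of TF 2.x.
print(tf.version.VERSION)  # e.g. '2.0.0' once the <2.0 pin is gone

# In TF 2.x, ConfigProto is only reachable through the compat namespace.
config = tf.compat.v1.ConfigProto(device_count={'GPU': 0})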

src/sagemaker_tensorflow_container/training.py

Lines changed: 2 additions & 2 deletions

@@ -96,9 +96,9 @@ def _run_ps(env, cluster):
     # Force parameter server to run on cpu. Running multiple TensorFlow processes on the same
     # GPU is not safe:
     # https://stackoverflow.com/questions/46145100/is-it-unsafe-to-run-multiple-tensorflow-processes-on-the-same-gpu
-    no_gpu_config = tf.ConfigProto(device_count={'GPU': 0})
+    no_gpu_config = tf.compat.v1.ConfigProto(device_count={'GPU': 0})

-    server = tf.train.Server(
+    server = tf.distribute.Server(
         cluster_spec, job_name='ps', task_index=task_index, config=no_gpu_config
     )
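
Both renames above are the TF 2.0 relocations of these symbols: tf.ConfigProto lives on as tf.compat.v1.ConfigProto, and tf.train.Server became tf.distribute.Server. Below is a self-contained sketch of the parameter-server pattern _run_ps implements, using a made-up single-host cluster layout (the real container derives its cluster spec from the SageMaker hosts):

import tensorflow as tf

# Hypothetical two-process cluster on one machine; the ports are illustrative.
cluster_spec = tf.train.ClusterSpec({
    'worker': ['localhost:2222'],
    'ps': ['localhost:2223'],
})

# Keep the parameter server off the GPU, per the comment in training.py.
no_gpu_config = tf.compat.v1.ConfigProto(device_count={'GPU': 0})

server = tf.distribute.Server(
    cluster_spec, job_name='ps', task_index=0, config=no_gpu_config
)
server.join()  # blocks, serving variable reads/writes for the workers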

test/unit/test_training.py

Lines changed: 4 additions & 4 deletions

@@ -102,7 +102,7 @@ def test_train_horovod(run_module, single_machine_training_env):
 @pytest.mark.skipif(sys.version_info.major != 3,
                     reason="Skip this for python 2 because of dict key order mismatch")
 @patch('tensorflow.train.ClusterSpec')
-@patch('tensorflow.train.Server')
+@patch('tensorflow.distribute.Server')
 @patch('sagemaker_containers.beta.framework.entry_point.run')
 @patch('multiprocessing.Process', lambda target: target())
 @patch('time.sleep', MagicMock())
@@ -114,7 +114,7 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
                                      'ps': ['host1:2223', 'host2:2223']})

     tf_server.assert_called_with(
-        cluster_spec(), job_name='ps', task_index=0, config=tf.ConfigProto(device_count={'GPU': 0})
+        cluster_spec(), job_name='ps', task_index=0, config=tf.compat.v1.ConfigProto(device_count={'GPU': 0})
     )
     tf_server().join.assert_called_with()

@@ -132,7 +132,7 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
 @pytest.mark.skipif(sys.version_info.major != 3,
                     reason="Skip this for python 2 because of dict key order mismatch")
 @patch('tensorflow.train.ClusterSpec')
-@patch('tensorflow.train.Server')
+@patch('tensorflow.distribute.Server')
 @patch('sagemaker_containers.beta.framework.entry_point.run')
 @patch('multiprocessing.Process', lambda target: target())
 @patch('time.sleep', MagicMock())
@@ -146,7 +146,7 @@ def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_trai
                                      'ps': ['host1:2223', 'host2:2223']})

     tf_server.assert_called_with(
-        cluster_spec(), job_name='ps', task_index=1, config=tf.ConfigProto(device_count={'GPU': 0})
+        cluster_spec(), job_name='ps', task_index=1, config=tf.compat.v1.ConfigProto(device_count={'GPU': 0})
     )
     tf_server().join.assert_called_with()
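
The test edits mirror the library edit because unittest.mock resolves patch targets from the dotted string: once training.py calls tf.distribute.Server, patching 'tensorflow.train.Server' would leave the real symbol in place and the assertions would never fire. A standalone sketch of that rule (the function here is hypothetical, not the repo's code):

from unittest.mock import patch

import tensorflow as tf

def start_ps_like_training_py():
    # Looks up the attribute the same way training.py now does.
    return tf.distribute.Server

with patch('tensorflow.distribute.Server') as tf_server:
    # The mock is visible only under the exact path that was patched.
    assert start_ps_like_training_py() is tf_server

Note also that comparing the recorded call against a freshly constructed tf.compat.v1.ConfigProto in assert_called_with works because protobuf messages compare by field value, not identity.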

0 commit comments
