Skip to content

Commit b11af98

Browse files
authored
breaking: Replace sagemaker-containers with sagemaker-training (#355)
1 parent 17524cb commit b11af98

File tree

4 files changed: 52 additions and 49 deletions.

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@ def read_version():
6262
'Programming Language :: Python :: 3.6',
6363
],
6464

65-
install_requires=['sagemaker-containers>=2.6.2', 'numpy', 'scipy', 'sklearn',
65+
install_requires=['sagemaker-training>=3.5.0', 'numpy', 'scipy', 'sklearn',
6666
'pandas', 'Pillow', 'h5py'],
6767
extras_require={
6868
'test': test_dependencies,

src/sagemaker_tensorflow_container/training.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
import subprocess
2020
import time
2121

22-
import sagemaker_containers.beta.framework as framework
22+
from sagemaker_training import entry_point, environment, mapping, runner
2323
import tensorflow as tf
2424

2525
from sagemaker_tensorflow_container import s3_utils
@@ -109,7 +109,7 @@ def _run_worker(env, cmd_args, tf_config):
109109
env_vars = env.to_env_vars()
110110
env_vars['TF_CONFIG'] = json.dumps(tf_config)
111111

112-
framework.entry_point.run(env.module_dir, env.user_entry_point, cmd_args, env_vars)
112+
entry_point.run(env.module_dir, env.user_entry_point, cmd_args, env_vars)
113113

114114

115115
def _wait_until_master_is_down(master):
@@ -128,7 +128,7 @@ def train(env, cmd_args):
128128
"""Get training job environment from env and run the training job.
129129
130130
Args:
131-
env (sagemaker_containers.beta.framework.env.TrainingEnv): Instance of TrainingEnv class
131+
env (sagemaker_training.environment.Environment): Instance of Environment class
132132
"""
133133
parameter_server_enabled = env.additional_framework_parameters.get(
134134
SAGEMAKER_PARAMETER_SERVER_ENABLED, False)
@@ -150,12 +150,15 @@ def train(env, cmd_args):
150150
mpi_enabled = env.additional_framework_parameters.get('sagemaker_mpi_enabled')
151151

152152
if mpi_enabled:
153-
runner_type = framework.runner.MPIRunnerType
153+
runner_type = runner.MPIRunnerType
154154
else:
155-
runner_type = framework.runner.ProcessRunnerType
155+
runner_type = runner.ProcessRunnerType
156156

157-
framework.entry_point.run(env.module_dir, env.user_entry_point, cmd_args, env.to_env_vars(),
158-
runner=runner_type)
157+
entry_point.run(env.module_dir,
158+
env.user_entry_point,
159+
cmd_args,
160+
env.to_env_vars(),
161+
runner_type=runner_type)
159162

160163

161164
def _log_model_missing_warning(model_dir):
@@ -195,8 +198,8 @@ def _model_dir_with_training_job(model_dir, job_name):
195198
def main():
196199
"""Training entry point
197200
"""
198-
hyperparameters = framework.env.read_hyperparameters()
199-
env = framework.training_env(hyperparameters=hyperparameters)
201+
hyperparameters = environment.read_hyperparameters()
202+
env = environment.Environment(hyperparameters=hyperparameters)
200203

201204
user_hyperparameters = env.hyperparameters
202205

@@ -208,5 +211,5 @@ def main():
208211
user_hyperparameters['model_dir'] = model_dir
209212

210213
s3_utils.configure(user_hyperparameters.get('model_dir'), os.environ.get('SAGEMAKER_REGION'))
211-
train(env, framework.mapping.to_cmd_args(user_hyperparameters))
214+
train(env, mapping.to_cmd_args(user_hyperparameters))
212215
_log_model_missing_warning(MODEL_DIR)

test-toolkit/unit/test_training.py

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
from mock import MagicMock, patch
1919
import pytest
20-
from sagemaker_containers.beta.framework import runner
20+
from sagemaker_training import runner
2121
import tensorflow as tf
2222

2323
from sagemaker_tensorflow_container import training
@@ -81,30 +81,30 @@ def test_is_host_master():
8181
assert training._is_host_master(HOST_LIST, 'somehost') is False
8282

8383

84-
@patch('sagemaker_containers.beta.framework.entry_point.run')
84+
@patch('sagemaker_training.entry_point.run')
8585
def test_single_machine(run_module, single_machine_training_env):
8686
training.train(single_machine_training_env, MODEL_DIR_CMD_LIST)
8787
run_module.assert_called_with(MODULE_DIR, MODULE_NAME, MODEL_DIR_CMD_LIST,
8888
single_machine_training_env.to_env_vars(),
89-
runner=runner.ProcessRunnerType)
89+
runner_type=runner.ProcessRunnerType)
9090

9191

92-
@patch('sagemaker_containers.beta.framework.entry_point.run')
92+
@patch('sagemaker_training.entry_point.run')
9393
def test_train_horovod(run_module, single_machine_training_env):
9494
single_machine_training_env.additional_framework_parameters['sagemaker_mpi_enabled'] = True
9595

9696
training.train(single_machine_training_env, MODEL_DIR_CMD_LIST)
9797
run_module.assert_called_with(MODULE_DIR, MODULE_NAME, MODEL_DIR_CMD_LIST,
9898
single_machine_training_env.to_env_vars(),
99-
runner=runner.MPIRunnerType)
99+
runner_type=runner.MPIRunnerType)
100100

101101

102102
@pytest.mark.skip_on_pipeline
103103
@pytest.mark.skipif(sys.version_info.major != 3,
104104
reason="Skip this for python 2 because of dict key order mismatch")
105105
@patch('tensorflow.train.ClusterSpec')
106106
@patch('tensorflow.distribute.Server')
107-
@patch('sagemaker_containers.beta.framework.entry_point.run')
107+
@patch('sagemaker_training.entry_point.run')
108108
@patch('multiprocessing.Process', lambda target: target())
109109
@patch('time.sleep', MagicMock())
110110
def test_train_distributed_master(run, tf_server, cluster_spec, distributed_training_env):
@@ -135,7 +135,7 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
135135
reason="Skip this for python 2 because of dict key order mismatch")
136136
@patch('tensorflow.train.ClusterSpec')
137137
@patch('tensorflow.distribute.Server')
138-
@patch('sagemaker_containers.beta.framework.entry_point.run')
138+
@patch('sagemaker_training.entry_point.run')
139139
@patch('multiprocessing.Process', lambda target: target())
140140
@patch('time.sleep', MagicMock())
141141
def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_training_env):
@@ -163,15 +163,15 @@ def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_trai
163163
{'TF_CONFIG': tf_config})
164164

165165

166-
@patch('sagemaker_containers.beta.framework.entry_point.run')
166+
@patch('sagemaker_training.entry_point.run')
167167
def test_train_distributed_no_ps(run, distributed_training_env):
168168
distributed_training_env.additional_framework_parameters[
169169
training.SAGEMAKER_PARAMETER_SERVER_ENABLED] = False
170170
distributed_training_env.current_host = HOST2
171171
training.train(distributed_training_env, MODEL_DIR_CMD_LIST)
172172

173173
run.assert_called_with(MODULE_DIR, MODULE_NAME, MODEL_DIR_CMD_LIST,
174-
distributed_training_env.to_env_vars(), runner=runner.ProcessRunnerType)
174+
distributed_training_env.to_env_vars(), runner_type=runner.ProcessRunnerType)
175175

176176

177177
def test_build_tf_config():
@@ -241,8 +241,8 @@ def test_log_model_missing_warning_correct(logger):
241241
@patch('sagemaker_tensorflow_container.training.logger')
242242
@patch('sagemaker_tensorflow_container.training.train')
243243
@patch('logging.Logger.setLevel')
244-
@patch('sagemaker_containers.beta.framework.training_env')
245-
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={})
244+
@patch('sagemaker_training.environment.Environment')
245+
@patch('sagemaker_training.environment.read_hyperparameters', return_value={})
246246
@patch('sagemaker_tensorflow_container.s3_utils.configure')
247247
def test_main(configure_s3_env, read_hyperparameters, training_env,
248248
set_level, train, logger, single_machine_training_env):
@@ -258,8 +258,8 @@ def test_main(configure_s3_env, read_hyperparameters, training_env,
258258
@patch('sagemaker_tensorflow_container.training.logger')
259259
@patch('sagemaker_tensorflow_container.training.train')
260260
@patch('logging.Logger.setLevel')
261-
@patch('sagemaker_containers.beta.framework.training_env')
262-
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={'model_dir': MODEL_DIR})
261+
@patch('sagemaker_training.environment.Environment')
262+
@patch('sagemaker_training.environment.read_hyperparameters', return_value={'model_dir': MODEL_DIR})
263263
@patch('sagemaker_tensorflow_container.s3_utils.configure')
264264
def test_main_simple_training_model_dir(configure_s3_env, read_hyperparameters, training_env,
265265
set_level, train, logger, single_machine_training_env):
@@ -272,9 +272,9 @@ def test_main_simple_training_model_dir(configure_s3_env, read_hyperparameters,
272272
@patch('sagemaker_tensorflow_container.training.logger')
273273
@patch('sagemaker_tensorflow_container.training.train')
274274
@patch('logging.Logger.setLevel')
275-
@patch('sagemaker_containers.beta.framework.training_env')
276-
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={'model_dir': MODEL_DIR,
277-
'_tuning_objective_metric': 'auc'})
275+
@patch('sagemaker_training.environment.Environment')
276+
@patch('sagemaker_training.environment.read_hyperparameters', return_value={'model_dir': MODEL_DIR,
277+
'_tuning_objective_metric': 'auc'})
278278
@patch('sagemaker_tensorflow_container.s3_utils.configure')
279279
def test_main_tuning_model_dir(configure_s3_env, read_hyperparameters, training_env,
280280
set_level, train, logger, single_machine_training_env):
@@ -288,9 +288,9 @@ def test_main_tuning_model_dir(configure_s3_env, read_hyperparameters, training_
288288
@patch('sagemaker_tensorflow_container.training.logger')
289289
@patch('sagemaker_tensorflow_container.training.train')
290290
@patch('logging.Logger.setLevel')
291-
@patch('sagemaker_containers.beta.framework.training_env')
292-
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={'model_dir': '/opt/ml/model',
293-
'_tuning_objective_metric': 'auc'})
291+
@patch('sagemaker_training.environment.Environment')
292+
@patch('sagemaker_training.environment.read_hyperparameters', return_value={'model_dir': '/opt/ml/model',
293+
'_tuning_objective_metric': 'auc'})
294294
@patch('sagemaker_tensorflow_container.s3_utils.configure')
295295
def test_main_tuning_mpi_model_dir(configure_s3_env, read_hyperparameters, training_env,
296296
set_level, train, logger, single_machine_training_env):

test/unit/test_training.py

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
from mock import MagicMock, patch
1919
import pytest
20-
from sagemaker_containers.beta.framework import runner
20+
from sagemaker_training import runner
2121
import tensorflow as tf
2222

2323
from sagemaker_tensorflow_container import training
@@ -81,30 +81,30 @@ def test_is_host_master():
8181
assert training._is_host_master(HOST_LIST, 'somehost') is False
8282

8383

84-
@patch('sagemaker_containers.beta.framework.entry_point.run')
84+
@patch('sagemaker_training.entry_point.run')
8585
def test_single_machine(run_module, single_machine_training_env):
8686
training.train(single_machine_training_env, MODEL_DIR_CMD_LIST)
8787
run_module.assert_called_with(MODULE_DIR, MODULE_NAME, MODEL_DIR_CMD_LIST,
8888
single_machine_training_env.to_env_vars(),
89-
runner=runner.ProcessRunnerType)
89+
runner_type=runner.ProcessRunnerType)
9090

9191

92-
@patch('sagemaker_containers.beta.framework.entry_point.run')
92+
@patch('sagemaker_training.entry_point.run')
9393
def test_train_horovod(run_module, single_machine_training_env):
9494
single_machine_training_env.additional_framework_parameters['sagemaker_mpi_enabled'] = True
9595

9696
training.train(single_machine_training_env, MODEL_DIR_CMD_LIST)
9797
run_module.assert_called_with(MODULE_DIR, MODULE_NAME, MODEL_DIR_CMD_LIST,
9898
single_machine_training_env.to_env_vars(),
99-
runner=runner.MPIRunnerType)
99+
runner_type=runner.MPIRunnerType)
100100

101101

102102
@pytest.mark.skip_on_pipeline
103103
@pytest.mark.skipif(sys.version_info.major != 3,
104104
reason="Skip this for python 2 because of dict key order mismatch")
105105
@patch('tensorflow.train.ClusterSpec')
106106
@patch('tensorflow.distribute.Server')
107-
@patch('sagemaker_containers.beta.framework.entry_point.run')
107+
@patch('sagemaker_training.entry_point.run')
108108
@patch('multiprocessing.Process', lambda target: target())
109109
@patch('time.sleep', MagicMock())
110110
def test_train_distributed_master(run, tf_server, cluster_spec, distributed_training_env):
@@ -135,7 +135,7 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
135135
reason="Skip this for python 2 because of dict key order mismatch")
136136
@patch('tensorflow.train.ClusterSpec')
137137
@patch('tensorflow.distribute.Server')
138-
@patch('sagemaker_containers.beta.framework.entry_point.run')
138+
@patch('sagemaker_training.entry_point.run')
139139
@patch('multiprocessing.Process', lambda target: target())
140140
@patch('time.sleep', MagicMock())
141141
def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_training_env):
@@ -163,15 +163,15 @@ def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_trai
163163
{'TF_CONFIG': tf_config})
164164

165165

166-
@patch('sagemaker_containers.beta.framework.entry_point.run')
166+
@patch('sagemaker_training.entry_point.run')
167167
def test_train_distributed_no_ps(run, distributed_training_env):
168168
distributed_training_env.additional_framework_parameters[
169169
training.SAGEMAKER_PARAMETER_SERVER_ENABLED] = False
170170
distributed_training_env.current_host = HOST2
171171
training.train(distributed_training_env, MODEL_DIR_CMD_LIST)
172172

173173
run.assert_called_with(MODULE_DIR, MODULE_NAME, MODEL_DIR_CMD_LIST,
174-
distributed_training_env.to_env_vars(), runner=runner.ProcessRunnerType)
174+
distributed_training_env.to_env_vars(), runner_type=runner.ProcessRunnerType)
175175

176176

177177
def test_build_tf_config():
@@ -241,8 +241,8 @@ def test_log_model_missing_warning_correct(logger):
241241
@patch('sagemaker_tensorflow_container.training.logger')
242242
@patch('sagemaker_tensorflow_container.training.train')
243243
@patch('logging.Logger.setLevel')
244-
@patch('sagemaker_containers.beta.framework.training_env')
245-
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={})
244+
@patch('sagemaker_training.environment.Environment')
245+
@patch('sagemaker_training.environment.read_hyperparameters', return_value={})
246246
@patch('sagemaker_tensorflow_container.s3_utils.configure')
247247
def test_main(configure_s3_env, read_hyperparameters, training_env,
248248
set_level, train, logger, single_machine_training_env):
@@ -258,8 +258,8 @@ def test_main(configure_s3_env, read_hyperparameters, training_env,
258258
@patch('sagemaker_tensorflow_container.training.logger')
259259
@patch('sagemaker_tensorflow_container.training.train')
260260
@patch('logging.Logger.setLevel')
261-
@patch('sagemaker_containers.beta.framework.training_env')
262-
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={'model_dir': MODEL_DIR})
261+
@patch('sagemaker_training.environment.Environment')
262+
@patch('sagemaker_training.environment.read_hyperparameters', return_value={'model_dir': MODEL_DIR})
263263
@patch('sagemaker_tensorflow_container.s3_utils.configure')
264264
def test_main_simple_training_model_dir(configure_s3_env, read_hyperparameters, training_env,
265265
set_level, train, logger, single_machine_training_env):
@@ -272,9 +272,9 @@ def test_main_simple_training_model_dir(configure_s3_env, read_hyperparameters,
272272
@patch('sagemaker_tensorflow_container.training.logger')
273273
@patch('sagemaker_tensorflow_container.training.train')
274274
@patch('logging.Logger.setLevel')
275-
@patch('sagemaker_containers.beta.framework.training_env')
276-
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={'model_dir': MODEL_DIR,
277-
'_tuning_objective_metric': 'auc'})
275+
@patch('sagemaker_training.environment.Environment')
276+
@patch('sagemaker_training.environment.read_hyperparameters', return_value={'model_dir': MODEL_DIR,
277+
'_tuning_objective_metric': 'auc'})
278278
@patch('sagemaker_tensorflow_container.s3_utils.configure')
279279
def test_main_tuning_model_dir(configure_s3_env, read_hyperparameters, training_env,
280280
set_level, train, logger, single_machine_training_env):
@@ -288,9 +288,9 @@ def test_main_tuning_model_dir(configure_s3_env, read_hyperparameters, training_
288288
@patch('sagemaker_tensorflow_container.training.logger')
289289
@patch('sagemaker_tensorflow_container.training.train')
290290
@patch('logging.Logger.setLevel')
291-
@patch('sagemaker_containers.beta.framework.training_env')
292-
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={'model_dir': '/opt/ml/model',
293-
'_tuning_objective_metric': 'auc'})
291+
@patch('sagemaker_training.environment.Environment')
292+
@patch('sagemaker_training.environment.read_hyperparameters', return_value={'model_dir': '/opt/ml/model',
293+
'_tuning_objective_metric': 'auc'})
294294
@patch('sagemaker_tensorflow_container.s3_utils.configure')
295295
def test_main_tuning_mpi_model_dir(configure_s3_env, read_hyperparameters, training_env,
296296
set_level, train, logger, single_machine_training_env):

0 commit comments

Comments
 (0)