 from mock import MagicMock, patch
 import pytest
-from sagemaker_containers.beta.framework import runner
+from sagemaker_training import runner
 import tensorflow as tf

 from sagemaker_tensorflow_container import training
@@ -81,30 +81,30 @@ def test_is_host_master():
     assert training._is_host_master(HOST_LIST, 'somehost') is False


-@patch('sagemaker_containers.beta.framework.entry_point.run')
+@patch('sagemaker_training.entry_point.run')
 def test_single_machine(run_module, single_machine_training_env):
     training.train(single_machine_training_env, MODEL_DIR_CMD_LIST)
     run_module.assert_called_with(MODULE_DIR, MODULE_NAME, MODEL_DIR_CMD_LIST,
                                   single_machine_training_env.to_env_vars(),
-                                  runner=runner.ProcessRunnerType)
+                                  runner_type=runner.ProcessRunnerType)


-@patch('sagemaker_containers.beta.framework.entry_point.run')
+@patch('sagemaker_training.entry_point.run')
 def test_train_horovod(run_module, single_machine_training_env):
     single_machine_training_env.additional_framework_parameters['sagemaker_mpi_enabled'] = True

     training.train(single_machine_training_env, MODEL_DIR_CMD_LIST)
     run_module.assert_called_with(MODULE_DIR, MODULE_NAME, MODEL_DIR_CMD_LIST,
                                   single_machine_training_env.to_env_vars(),
-                                  runner=runner.MPIRunnerType)
+                                  runner_type=runner.MPIRunnerType)


 @pytest.mark.skip_on_pipeline
 @pytest.mark.skipif(sys.version_info.major != 3,
                     reason="Skip this for python 2 because of dict key order mismatch")
 @patch('tensorflow.train.ClusterSpec')
 @patch('tensorflow.distribute.Server')
-@patch('sagemaker_containers.beta.framework.entry_point.run')
+@patch('sagemaker_training.entry_point.run')
 @patch('multiprocessing.Process', lambda target: target())
 @patch('time.sleep', MagicMock())
 def test_train_distributed_master(run, tf_server, cluster_spec, distributed_training_env):
@@ -135,7 +135,7 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
                     reason="Skip this for python 2 because of dict key order mismatch")
 @patch('tensorflow.train.ClusterSpec')
 @patch('tensorflow.distribute.Server')
-@patch('sagemaker_containers.beta.framework.entry_point.run')
+@patch('sagemaker_training.entry_point.run')
 @patch('multiprocessing.Process', lambda target: target())
 @patch('time.sleep', MagicMock())
 def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_training_env):
@@ -163,15 +163,15 @@ def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_trai
                            {'TF_CONFIG': tf_config})


-@patch('sagemaker_containers.beta.framework.entry_point.run')
+@patch('sagemaker_training.entry_point.run')
 def test_train_distributed_no_ps(run, distributed_training_env):
     distributed_training_env.additional_framework_parameters[
         training.SAGEMAKER_PARAMETER_SERVER_ENABLED] = False
     distributed_training_env.current_host = HOST2
     training.train(distributed_training_env, MODEL_DIR_CMD_LIST)

     run.assert_called_with(MODULE_DIR, MODULE_NAME, MODEL_DIR_CMD_LIST,
-                           distributed_training_env.to_env_vars(), runner=runner.ProcessRunnerType)
+                           distributed_training_env.to_env_vars(), runner_type=runner.ProcessRunnerType)


 def test_build_tf_config():
@@ -241,8 +241,8 @@ def test_log_model_missing_warning_correct(logger):
 @patch('sagemaker_tensorflow_container.training.logger')
 @patch('sagemaker_tensorflow_container.training.train')
 @patch('logging.Logger.setLevel')
-@patch('sagemaker_containers.beta.framework.training_env')
-@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={})
+@patch('sagemaker_training.environment.Environment')
+@patch('sagemaker_training.environment.read_hyperparameters', return_value={})
 @patch('sagemaker_tensorflow_container.s3_utils.configure')
 def test_main(configure_s3_env, read_hyperparameters, training_env,
               set_level, train, logger, single_machine_training_env):
@@ -258,8 +258,8 @@ def test_main(configure_s3_env, read_hyperparameters, training_env,
 @patch('sagemaker_tensorflow_container.training.logger')
 @patch('sagemaker_tensorflow_container.training.train')
 @patch('logging.Logger.setLevel')
-@patch('sagemaker_containers.beta.framework.training_env')
-@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={'model_dir': MODEL_DIR})
+@patch('sagemaker_training.environment.Environment')
+@patch('sagemaker_training.environment.read_hyperparameters', return_value={'model_dir': MODEL_DIR})
 @patch('sagemaker_tensorflow_container.s3_utils.configure')
 def test_main_simple_training_model_dir(configure_s3_env, read_hyperparameters, training_env,
                                         set_level, train, logger, single_machine_training_env):
@@ -272,9 +272,9 @@ def test_main_simple_training_model_dir(configure_s3_env, read_hyperparameters,
 @patch('sagemaker_tensorflow_container.training.logger')
 @patch('sagemaker_tensorflow_container.training.train')
 @patch('logging.Logger.setLevel')
-@patch('sagemaker_containers.beta.framework.training_env')
-@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={'model_dir': MODEL_DIR,
-                                                                                      '_tuning_objective_metric': 'auc'})
+@patch('sagemaker_training.environment.Environment')
+@patch('sagemaker_training.environment.read_hyperparameters', return_value={'model_dir': MODEL_DIR,
+                                                                             '_tuning_objective_metric': 'auc'})
 @patch('sagemaker_tensorflow_container.s3_utils.configure')
 def test_main_tuning_model_dir(configure_s3_env, read_hyperparameters, training_env,
                                set_level, train, logger, single_machine_training_env):
@@ -288,9 +288,9 @@ def test_main_tuning_model_dir(configure_s3_env, read_hyperparameters, training_
 @patch('sagemaker_tensorflow_container.training.logger')
 @patch('sagemaker_tensorflow_container.training.train')
 @patch('logging.Logger.setLevel')
-@patch('sagemaker_containers.beta.framework.training_env')
-@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={'model_dir': '/opt/ml/model',
-                                                                                      '_tuning_objective_metric': 'auc'})
+@patch('sagemaker_training.environment.Environment')
+@patch('sagemaker_training.environment.read_hyperparameters', return_value={'model_dir': '/opt/ml/model',
+                                                                             '_tuning_objective_metric': 'auc'})
 @patch('sagemaker_tensorflow_container.s3_utils.configure')
 def test_main_tuning_mpi_model_dir(configure_s3_env, read_hyperparameters, training_env,
                                    set_level, train, logger, single_machine_training_env):
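
For orientation, a minimal sketch of the toolkit call these tests mock out after the rename, assuming the sagemaker_training.entry_point.run argument order exercised by the assertions above (module location, entry point name, CLI args, environment variables) plus the runner_type keyword; the bucket, script name, and argument values are made-up placeholders, and patching the call keeps the sketch runnable without real training resources.

from mock import patch

from sagemaker_training import entry_point, runner


@patch('sagemaker_training.entry_point.run')
def illustrate_runner_type(run_mock):
    # The patch target is the new module path; in the tests above,
    # sagemaker_tensorflow_container.training.train ends up calling this with a
    # runner_type chosen from the distribution settings (ProcessRunnerType for
    # single-machine jobs, MPIRunnerType when sagemaker_mpi_enabled is set).
    entry_point.run('s3://my-bucket/sourcedir.tar.gz',       # module location (placeholder)
                    'train.py',                               # user entry point (placeholder)
                    ['--model_dir', 's3://my-bucket/model'],  # CLI args forwarded to the script
                    {'SM_MODEL_DIR': '/opt/ml/model'},        # environment variables
                    runner_type=runner.MPIRunnerType)         # e.g. MPI when Horovod is enabled
    run_mock.assert_called_once()


illustrate_runner_type()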