1
- # Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1
+ # Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
2
#
3
3
# Licensed under the Apache License, Version 2.0 (the "License"). You
4
4
# may not use this file except in compliance with the License. A copy of
19
19
from six .moves .urllib .parse import urlparse
20
20
21
21
from sagemaker_tensorflow_container .training import SAGEMAKER_PARAMETER_SERVER_ENABLED
22
+ from utils import unique_name_from_base
22
23
23
24
24
25
def test_mnist (sagemaker_session , ecr_image , instance_type , framework_version ):
@@ -31,12 +32,11 @@ def test_mnist(sagemaker_session, ecr_image, instance_type, framework_version):
31
32
sagemaker_session = sagemaker_session ,
32
33
image_name = ecr_image ,
33
34
framework_version = framework_version ,
34
- py_version = 'py3' ,
35
- base_job_name = 'test-sagemaker-mnist' )
35
+ script_mode = True )
36
36
inputs = estimator .sagemaker_session .upload_data (
37
37
path = os .path .join (resource_path , 'mnist' , 'data' ),
38
38
key_prefix = 'scriptmode/mnist' )
39
- estimator .fit (inputs )
39
+ estimator .fit (inputs , job_name = unique_name_from_base ( 'test-sagemaker-mnist' ) )
40
40
_assert_s3_file_exists (sagemaker_session .boto_region_name , estimator .model_data )
41
41
42
42
@@ -50,12 +50,11 @@ def test_distributed_mnist_no_ps(sagemaker_session, ecr_image, instance_type, fr
50
50
sagemaker_session = sagemaker_session ,
51
51
image_name = ecr_image ,
52
52
framework_version = framework_version ,
53
- py_version = 'py3' ,
54
- base_job_name = 'test-tf-sm-distributed-mnist' )
53
+ script_mode = True )
55
54
inputs = estimator .sagemaker_session .upload_data (
56
55
path = os .path .join (resource_path , 'mnist' , 'data' ),
57
56
key_prefix = 'scriptmode/mnist' )
58
- estimator .fit (inputs )
57
+ estimator .fit (inputs , job_name = unique_name_from_base ( 'test-tf-sm-distributed-mnist' ) )
59
58
_assert_s3_file_exists (sagemaker_session .boto_region_name , estimator .model_data )
60
59
61
60
@@ -70,12 +69,11 @@ def test_distributed_mnist_ps(sagemaker_session, ecr_image, instance_type, frame
70
69
sagemaker_session = sagemaker_session ,
71
70
image_name = ecr_image ,
72
71
framework_version = framework_version ,
73
- py_version = 'py3' ,
74
- base_job_name = 'test-tf-sm-distributed-mnist' )
72
+ script_mode = True )
75
73
inputs = estimator .sagemaker_session .upload_data (
76
74
path = os .path .join (resource_path , 'mnist' , 'data-distributed' ),
77
75
key_prefix = 'scriptmode/mnist-distributed' )
78
- estimator .fit (inputs )
76
+ estimator .fit (inputs , job_name = unique_name_from_base ( 'test-tf-sm-distributed-mnist' ) )
79
77
_assert_checkpoint_exists (sagemaker_session .boto_region_name , estimator .model_dir , 0 )
80
78
_assert_s3_file_exists (sagemaker_session .boto_region_name , estimator .model_data )
81
79
@@ -104,9 +102,9 @@ def test_s3_plugin(sagemaker_session, ecr_image, instance_type, region, framewor
104
102
sagemaker_session = sagemaker_session ,
105
103
image_name = ecr_image ,
106
104
framework_version = framework_version ,
107
- py_version = 'py3' ,
108
- base_job_name = 'test-tf-sm-s3- mnist')
109
- estimator . fit ( 's3://sagemaker-sample-data-{}/tensorflow/ mnist'. format ( region ))
105
+ script_mode = True )
106
+ estimator . fit ( 's3://sagemaker-sample-data-{}/tensorflow/ mnist'. format ( region ),
107
+ job_name = unique_name_from_base ( 'test-tf-sm-s3- mnist' ))
110
108
_assert_s3_file_exists (region , estimator .model_data )
111
109
_assert_checkpoint_exists (region , estimator .model_dir , 200 )
112
110
0 commit comments