Skip to content

Commit 5076641

Browse files
authored
Add sagemaker-experiments (#301)
1 parent 899e703 commit 5076641

File tree

5 files changed

+108
-1
lines changed

5 files changed

+108
-1
lines changed

docker/2.1.0/py3/Dockerfile.cpu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ RUN ${PIP} install --no-cache-dir -U \
109109
awscli \
110110
mpi4py==3.0.3 \
111111
opencv-python==4.2.0.32 \
112+
sagemaker-experiments==0.1.7 \
112113
"sagemaker-tensorflow>=2.1,<2.2" \
113114
# Let's install TensorFlow separately in the end to avoid
114115
# the library version to be overwritten

docker/2.1.0/py3/Dockerfile.gpu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ RUN ${PIP} install --no-cache-dir -U \
153153
awscli \
154154
mpi4py==3.0.3 \
155155
opencv-python==4.2.0.32 \
156+
sagemaker-experiments==0.1.7 \
156157
"sagemaker-tensorflow>=2.1,<2.2" \
157158
# Let's install TensorFlow separately in the end to avoid
158159
# the library version to be overwritten

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def read_version():
6060
'sagemaker==1.50.1', 'tensorflow<2.0', 'docker-compose', 'boto3==1.10.50',
6161
'six==1.13.0', 'python-dateutil>=2.1,<2.8.1', 'botocore==1.13.50',
6262
'requests-mock', 'awscli==1.16.314'],
63-
'benchmark': ['click']
63+
'benchmark': ['click'],
64+
':python_version=="3.6"': ['sagemaker-experiments==0.1.7']
6465
},
6566
)

test/integration/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,10 @@ def docker_image(docker_base_name, tag):
116116
def ecr_image(account_id, docker_base_name, tag, region):
117117
return '{}.dkr.ecr.{}.amazonaws.com/{}:{}'.format(
118118
account_id, region, docker_base_name, tag)
119+
120+
121+
@pytest.fixture(autouse=True)
def skip_py2_containers(request, tag):
    """Skip tests marked ``skip_py2_containers`` when the image tag contains 'py2'.

    Runs automatically for every test (``autouse=True``). Tests without the
    marker, or running against a tag that does not contain ``'py2'``, are
    left untouched.
    """
    marker = request.node.get_closest_marker('skip_py2_containers')
    if not marker:
        return
    if 'py2' not in tag:
        return
    pytest.skip('Skipping python2 container with tag {}'.format(tag))
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
import time
17+
18+
import pytest
19+
from sagemaker import utils
20+
from sagemaker.tensorflow import TensorFlow
21+
from smexperiments.experiment import Experiment
22+
from smexperiments.trial import Trial
23+
from smexperiments.trial_component import TrialComponent
24+
25+
from test.integration import DEFAULT_TIMEOUT
26+
from test.integration import RESOURCE_PATH
27+
from timeout import timeout
28+
29+
# Directory holding the MNIST test resources, located under RESOURCE_PATH.
DATA_PATH = os.path.join(RESOURCE_PATH, "mnist")
30+
# Path to a script inside DATA_PATH. test_training below does not use this
# constant; it builds its own entry-point path relative to __file__.
# NOTE(review): the "gluon" file name looks carried over from another
# framework's tests -- confirm the file exists or that this constant is needed.
SCRIPT_PATH = os.path.join(DATA_PATH, "mnist_gluon_basic_hook_demo.py")
31+
32+
33+
@pytest.mark.skip_py2_containers
def test_training(sagemaker_session, ecr_image, instance_type, framework_version):
    """Run a training job and check its trial component can be attached to a trial.

    The test creates an experiment and a trial, runs an MNIST training job
    with the TensorFlow estimator, looks up the trial component whose source
    ARN is the training job, associates it with the trial, asserts that the
    trial then lists exactly one component, and finally deletes everything
    it created.

    Args:
        sagemaker_session: session fixture; its ``sagemaker_client`` is used
            for every experiments API call and for ``describe_training_job``.
        ecr_image: image URI passed to the estimator as ``image_name``.
        instance_type: training instance type passed to the estimator.
        framework_version: framework version passed to the estimator.
    """
    sm_client = sagemaker_session.sagemaker_client

    experiment_name = f"tf-container-integ-test-{int(time.time())}"

    experiment = Experiment.create(
        experiment_name=experiment_name,
        description="Integration test experiment from sagemaker-tf-container",
        sagemaker_boto_client=sm_client,
    )

    trial_name = f"tf-container-integ-test-{int(time.time())}"
    trial = Trial.create(
        experiment_name=experiment_name, trial_name=trial_name, sagemaker_boto_client=sm_client
    )

    training_job_name = utils.unique_name_from_base("test-tf-experiments-mnist")

    # create a training job and wait for it to complete
    with timeout(minutes=DEFAULT_TIMEOUT):
        resource_path = os.path.join(os.path.dirname(__file__), "..", "..", "resources")
        script = os.path.join(resource_path, "mnist", "mnist.py")
        estimator = TensorFlow(
            entry_point=script,
            role="SageMakerRole",
            train_instance_type=instance_type,
            train_instance_count=1,
            sagemaker_session=sagemaker_session,
            image_name=ecr_image,
            framework_version=framework_version,
            script_mode=True,
        )
        inputs = estimator.sagemaker_session.upload_data(
            path=os.path.join(resource_path, "mnist", "data"), key_prefix="scriptmode/mnist"
        )
        estimator.fit(inputs, job_name=training_job_name)

    training_job = sm_client.describe_training_job(TrainingJobName=training_job_name)
    training_job_arn = training_job["TrainingJobArn"]

    # verify trial component auto created from the training job
    trial_components = list(
        TrialComponent.list(source_arn=training_job_arn, sagemaker_boto_client=sm_client)
    )
    # Fail with a clear message instead of an IndexError on the lookup below.
    assert trial_components, "No trial component found for training job {}".format(
        training_job_arn
    )

    trial_component_summary = trial_components[0]
    # The keyword must be ``sagemaker_boto_client`` (as in every other
    # smexperiments call here) so the component is loaded with the same client.
    trial_component = TrialComponent.load(
        trial_component_name=trial_component_summary.trial_component_name,
        sagemaker_boto_client=sm_client,
    )

    # associate the trial component with the trial
    trial.add_trial_component(trial_component)

    # verify association
    associated_trial_components = list(trial.list_trial_components())
    assert len(associated_trial_components) == 1

    # cleanup
    trial.remove_trial_component(trial_component_summary.trial_component_name)
    trial_component.delete()
    trial.delete()
    experiment.delete()

0 commit comments

Comments
 (0)