Skip to content

Commit 92eb20a

Browse files
author
Dan
authored
feature: use tensorflow 2.3.1 and add data parallel integ test (#411)
1 parent 0d9f3fa commit 92eb20a

File tree

7 files changed

+167
-7
lines changed

7 files changed

+167
-7
lines changed

buildspec.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ version: 0.2
22

33
env:
44
variables:
5-
FRAMEWORK_VERSION: '2.2.0'
5+
FRAMEWORK_VERSION: '2.3.1'
66
CPU_INSTANCE_TYPE: 'ml.c4.xlarge'
77
GPU_INSTANCE_TYPE: 'ml.p2.xlarge'
88
ECR_REPO: 'sagemaker-test'
@@ -61,23 +61,23 @@ phases:
6161
# run GPU local integration tests
6262
- printf "$SETUP_CMDS" > $SETUP_FILE
6363
# no reason to rebuild the image again since it was already built and pushed to ECR during CPU tests
64-
- generic_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/local --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --account-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor gpu --tag $GENERIC_TAG"
64+
- generic_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/local --dockerfile-type tf --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --account-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor gpu --tag $GENERIC_TAG"
6565
- test_cmd="remote-test --github-repo $GITHUB_REPO --test-cmd \"$generic_cmd\" --setup-file $SETUP_FILE --pr-number \"$PR_NUM\""
6666
- execute-command-if-has-matching-changes "$test_cmd" "test/" "src/*.py" "setup.py" "setup.cfg" "buildspec.yml"
67-
- dlc_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/local --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --account-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor gpu --tag $DLC_GPU_TAG"
67+
- dlc_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/local --dockerfile-type dlc.gpu --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --account-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor gpu --tag $DLC_GPU_TAG"
6868
- test_cmd="remote-test --github-repo $GITHUB_REPO --test-cmd \"$dlc_cmd\" --setup-file $SETUP_FILE --pr-number \"$PR_NUM\" --skip-setup"
6969
- execute-command-if-has-matching-changes "$test_cmd" "test/" "src/*.py" "setup.py" "setup.cfg" "buildspec.yml"
7070

7171
# run CPU sagemaker integration tests
72-
- test_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/sagemaker --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --account-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor cpu --instance-type $CPU_INSTANCE_TYPE --tag $GENERIC_TAG"
72+
- test_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/sagemaker -n auto --reruns 3 --reruns-delay 15 --dockerfile-type tf --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --account-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor cpu --instance-type $CPU_INSTANCE_TYPE --tag $GENERIC_TAG"
7373
- execute-command-if-has-matching-changes "$test_cmd" "test/" "src/*.py" "setup.py" "setup.cfg" "buildspec.yml"
74-
- test_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/sagemaker --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --account-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor cpu --instance-type $CPU_INSTANCE_TYPE --tag $DLC_CPU_TAG"
74+
- test_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/sagemaker -n auto --reruns 3 --reruns-delay 15 --dockerfile-type dlc.cpu --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --account-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor cpu --instance-type $CPU_INSTANCE_TYPE --tag $DLC_CPU_TAG"
7575
- execute-command-if-has-matching-changes "$test_cmd" "test/" "src/*.py" "setup.py" "setup.cfg" "buildspec.yml"
7676

7777
# run GPU sagemaker integration tests
78-
- test_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/sagemaker --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --account-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor gpu --instance-type $GPU_INSTANCE_TYPE --tag $GENERIC_TAG"
78+
- test_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/sagemaker -n auto --reruns 3 --reruns-delay 15 --dockerfile-type tf --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --account-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor gpu --instance-type $GPU_INSTANCE_TYPE --tag $GENERIC_TAG"
7979
- execute-command-if-has-matching-changes "$test_cmd" "test/" "src/*.py" "setup.py" "setup.cfg" "buildspec.yml"
80-
- test_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/sagemaker --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --account-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor gpu --instance-type $GPU_INSTANCE_TYPE --tag $DLC_GPU_TAG"
80+
- test_cmd="IGNORE_COVERAGE=- tox -e py36 -- test/integration/sagemaker -n auto --reruns 3 --reruns-delay 15 --dockerfile-type dlc.gpu --region $AWS_DEFAULT_REGION --docker-base-name $ECR_REPO --account-id $ACCOUNT --framework-version $FRAMEWORK_VERSION --processor gpu --instance-type $GPU_INSTANCE_TYPE --tag $DLC_GPU_TAG"
8181
- execute-command-if-has-matching-changes "$test_cmd" "test/" "src/*.py" "setup.py" "setup.cfg" "buildspec.yml"
8282
finally:
8383
# shut down remote GPU instance

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def read_version():
3535
"pytest",
3636
"pytest-cov",
3737
"pytest-xdist",
38+
"pytest-rerunfailures",
3839
"mock",
3940
"sagemaker[local]>=2",
4041
"tensorflow<2.4",
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
ARG region
2+
FROM 763104351884.dkr.ecr.$region.amazonaws.com/tensorflow-training:2.3.1-cpu-py37
3+
4+
COPY dist/sagemaker_tensorflow_training-*.tar.gz /sagemaker_tensorflow_training.tar.gz
5+
RUN pip install --upgrade --no-cache-dir /sagemaker_tensorflow_training.tar.gz && \
6+
rm /sagemaker_tensorflow_training.tar.gz
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
ARG region
2+
FROM 763104351884.dkr.ecr.$region.amazonaws.com/tensorflow-training:2.3.1-gpu-py37-cu110-ubuntu18.04
3+
4+
COPY dist/sagemaker_tensorflow_training-*.tar.gz /sagemaker_tensorflow_training.tar.gz
5+
RUN pip install --upgrade --no-cache-dir /sagemaker_tensorflow_training.tar.gz && \
6+
rm /sagemaker_tensorflow_training.tar.gz

test/container/2.3.1/Dockerfile.tf

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
FROM tensorflow/tensorflow:2.3.1-gpu
2+
3+
ENV SAGEMAKER_TRAINING_MODULE sagemaker_tensorflow_container.training:main
4+
5+
COPY dist/sagemaker_tensorflow_training-*.tar.gz /sagemaker_tensorflow_training.tar.gz
6+
RUN pip install --upgrade --no-cache-dir /sagemaker_tensorflow_training.tar.gz && \
7+
rm /sagemaker_tensorflow_training.tar.gz
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
17+
import pytest
18+
import sagemaker
19+
from sagemaker.tensorflow import TensorFlow
20+
from sagemaker.utils import unique_name_from_base
21+
22+
from integration import DEFAULT_TIMEOUT, RESOURCE_PATH
23+
from integration.sagemaker.timeout import timeout
24+
25+
26+
@pytest.mark.skip_cpu
@pytest.mark.skip_generic
@pytest.mark.parametrize(
    "instances, instance_type",
    [(2, "ml.p3.16xlarge")],
)
def test_smdataparallel_training(
    instances, instance_type, sagemaker_session, image_uri, framework_version, tmpdir
):
    """Train the smdataparallel MNIST script on SageMaker and check its artifact.

    Launches a TensorFlow estimator with the ``smdistributed`` data-parallel
    distribution enabled, waits for the training job to finish, and verifies
    that the job's model output consists of ``model.tar.gz``.

    Args:
        instances: Number of training instances to launch.
        instance_type: SageMaker instance type used for every instance.
        sagemaker_session: Session fixture used for the bucket lookup, the
            training job and the artifact listing.
        image_uri: Training image under test.
        framework_version: TensorFlow version passed to the estimator.
        tmpdir: Unused here; kept so the test signature stays unchanged.
    """
    # NOTE(review): skip_cpu / skip_generic presumably exclude CPU runs and the
    # generic (non-DLC) image -- confirm against the markers in conftest.
    default_bucket = sagemaker_session.default_bucket()
    output_path = "s3://{}/{}/{}".format(default_bucket, "tensorflow", "smdataparallel")

    estimator = TensorFlow(
        entry_point=os.path.join(RESOURCE_PATH, "mnist", "smdataparallel_mnist.py"),
        role="SageMakerRole",
        instance_type=instance_type,
        sagemaker_session=sagemaker_session,
        instance_count=instances,
        image_uri=image_uri,
        output_path=output_path,
        framework_version=framework_version,
        py_version="py3",
        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    )

    with timeout(minutes=DEFAULT_TIMEOUT):
        estimator.fit(job_name=unique_name_from_base("test-tf-smdataparallel"))

    model_data_source = sagemaker.local.data.get_data_source_instance(
        estimator.model_data, sagemaker_session
    )

    file_list = list(model_data_source.get_file_list())
    # Without this check the loop below passes vacuously when nothing was uploaded.
    assert file_list, "no model artifacts found at {}".format(estimator.model_data)
    for filename in file_list:
        assert os.path.basename(filename) == "model.tar.gz"
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
import tensorflow as tf
14+
15+
import smdistributed.dataparallel.tensorflow as dist
16+
17+
# Seed TensorFlow's global random generator.
tf.random.set_seed(42)

# Initialise the data-parallel library before any rank/size query below.
dist.init()

gpus = tf.config.experimental.list_physical_devices("GPU")
if gpus:
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
    # Make only the GPU indexed by this process's local rank visible.
    tf.config.experimental.set_visible_devices(gpus[dist.local_rank()], "GPU")

# The cache file name embeds the rank (presumably so that workers do not
# share one download target -- TODO confirm).
mnist_cache_name = "mnist-%d.npz" % dist.rank()
(train_images, train_labels), _ = tf.keras.datasets.mnist.load_data(path=mnist_cache_name)

# Add a trailing channel axis, scale pixel values by 1/255 and cast for training.
scaled_images = tf.cast(train_images[..., tf.newaxis] / 255.0, tf.float32)
integer_labels = tf.cast(train_labels, tf.int64)

dataset = tf.data.Dataset.from_tensor_slices((scaled_images, integer_labels))
dataset = dataset.repeat()
dataset = dataset.shuffle(10000)
dataset = dataset.batch(128)

model_layers = [
    tf.keras.layers.Conv2D(32, [3, 3], activation="relu"),
    tf.keras.layers.Conv2D(64, [3, 3], activation="relu"),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation="softmax"),
]
mnist_model = tf.keras.Sequential(model_layers)

loss = tf.losses.SparseCategoricalCrossentropy()

# The learning rate is the base value multiplied by dist.size():
# 0.000125 when dist.size() == 1, 0.001 when dist.size() == 8.
base_learning_rate = 0.000125
opt = tf.optimizers.Adam(base_learning_rate * dist.size())

checkpoint_dir = "./checkpoints"
checkpoint = tf.train.Checkpoint(model=mnist_model, optimizer=opt)
55+
56+
57+
@tf.function
def training_step(images, labels, first_batch):
    """Run one optimisation step and return the loss reduced across workers.

    Args:
        images: Batch of input images fed to ``mnist_model``.
        labels: Labels matching ``images``, passed to the loss function.
        first_batch: When true, model and optimizer variables are broadcast
            from rank 0 after the gradients have been applied.

    Returns:
        The result of ``dist.oob_allreduce`` applied to this batch's loss.
    """
    with tf.GradientTape() as local_tape:
        predictions = mnist_model(images, training=True)
        batch_loss = loss(labels, predictions)

    # Wrap the recorded tape so gradients come from the distributed tape.
    distributed_tape = dist.DistributedGradientTape(local_tape)

    trainable = mnist_model.trainable_variables
    gradients = distributed_tape.gradient(batch_loss, trainable)
    opt.apply_gradients(zip(gradients, trainable))

    if first_batch:
        # NOTE(review): the broadcast runs after the first apply_gradients,
        # presumably so the optimizer's variables already exist -- confirm.
        dist.broadcast_variables(mnist_model.variables, root_rank=0)
        dist.broadcast_variables(opt.variables(), root_rank=0)

    # Average the loss across workers.
    return dist.oob_allreduce(batch_loss)
74+
75+
76+
# A budget of 10000 steps is divided by dist.size() to get this worker's count.
total_steps = 10000 // dist.size()

for step, (images, labels) in enumerate(dataset.take(total_steps)):
    # Only the very first step asks training_step to broadcast variables.
    loss_value = training_step(images, labels, step == 0)

    # Log every 50th step, and only from rank 0.
    if step % 50 == 0 and dist.rank() == 0:
        print("Step #%d\tLoss: %.6f" % (step, loss_value))

# Only rank 0 writes the checkpoint.
if dist.rank() == 0:
    checkpoint.save(checkpoint_dir)

0 commit comments

Comments
 (0)