Skip to content

Commit 4ed2b38

Browse files
authored
change: copy all tests to test-toolkit folder. (#292)
1 parent 63896dc commit 4ed2b38

38 files changed

+1822
-1
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[flake8]
2-
application_import_names = sagemaker_tensorflow_container, test, timeout, utils
2+
application_import_names = sagemaker_tensorflow_container, test, test-toolkit, timeout, utils
33
import-order-style = google

test-toolkit/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
from __future__ import absolute_import

test-toolkit/integration/__init__.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import logging
16+
import os
17+
18+
# Pin the AWS SDK loggers to INFO.
for _sdk_logger_name in ('boto3', 'botocore'):
    logging.getLogger(_sdk_logger_name).setLevel(logging.INFO)

# Shared test resources: the ``resources`` directory one level above this
# package's directory.
_PACKAGE_DIR = os.path.dirname(__file__)
RESOURCE_PATH = os.path.join(_PACKAGE_DIR, '..', 'resources')
22+
23+
# These regions have some p2 and p3 instances, but not enough for automated
# testing.
#
# Every entry -- including the last one -- ends with a comma.  Python joins
# adjacent string literals, so a missing comma silently merges two regions
# into one bogus name (e.g. 'us-west-1eu-west-3') and neither is matched.
NO_P2_REGIONS = [
    'ca-central-1',
    'eu-central-1',
    'eu-west-2',
    'us-west-1',
    'eu-west-3',
    'eu-north-1',
    'sa-east-1',
    'ap-east-1',
    'me-south-1',
]
NO_P3_REGIONS = [
    'ap-southeast-1',
    'ap-southeast-2',
    'ap-south-1',
    'ca-central-1',
    'eu-central-1',
    'eu-west-2',
    'us-west-1',
    'eu-west-3',
    'eu-north-1',
    'sa-east-1',
    'ap-east-1',
    'me-south-1',
]

test-toolkit/integration/conftest.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import logging
16+
import os
17+
18+
import boto3
19+
import pytest
20+
from sagemaker import LocalSession, Session
21+
from sagemaker.tensorflow import TensorFlow
22+
23+
from test.integration import NO_P2_REGIONS, NO_P3_REGIONS
24+
25+
logger = logging.getLogger(__name__)

# Set each of the loggers named below to INFO.
# NOTE(review): 'factory.py', 'auth.py' and 'connectionpool.py' look like file
# names rather than dotted logger names -- confirm they match real loggers.
for _logger_name in ('boto',
                     'botocore',
                     'factory.py',
                     'auth.py',
                     'connectionpool.py'):
    logging.getLogger(_logger_name).setLevel(logging.INFO)

# Directory that contains this file, with symlinks resolved.
_THIS_FILE = os.path.realpath(__file__)
SCRIPT_PATH = os.path.dirname(_THIS_FILE)
33+
34+
35+
def pytest_addoption(parser):
    """Register the command-line options used by these integration tests.

    Args:
        parser: option parser handed in by pytest; only ``addoption`` is
            called on it, once per option and in the order listed below.
    """
    option_specs = (
        ('--docker-base-name', {'default': 'sagemaker-tensorflow-scriptmode'}),
        ('--tag', {'default': None}),
        ('--region', {'default': 'us-west-2'}),
        ('--framework-version', {'default': TensorFlow.LATEST_VERSION}),
        ('--processor', {'default': 'cpu', 'choices': ['cpu', 'gpu', 'cpu,gpu']}),
        ('--py-version', {'default': '3', 'choices': ['2', '3', '2,3']}),
        ('--account-id', {'default': '142577830533'}),
        ('--instance-type', {'default': None}),
    )
    for flag, settings in option_specs:
        parser.addoption(flag, **settings)
44+
45+
46+
def pytest_configure(config):
    """Copy the ``--py-version`` and ``--processor`` options into the environment.

    Args:
        config: pytest config object; its ``getoption`` supplies the values
            stored in ``TEST_PY_VERSIONS`` and ``TEST_PROCESSORS``.
    """
    exported_options = (
        ('TEST_PY_VERSIONS', '--py-version'),
        ('TEST_PROCESSORS', '--processor'),
    )
    for env_name, option_name in exported_options:
        os.environ[env_name] = config.getoption(option_name)
49+
50+
51+
@pytest.fixture(scope='session')
def docker_base_name(request):
    """Return the value of the ``--docker-base-name`` command-line option."""
    options = request.config
    return options.getoption('--docker-base-name')
54+
55+
56+
@pytest.fixture(scope='session')
def region(request):
    """Return the value of the ``--region`` command-line option."""
    options = request.config
    return options.getoption('--region')
59+
60+
61+
@pytest.fixture(scope='session')
def framework_version(request):
    """Return the value of the ``--framework-version`` command-line option."""
    options = request.config
    return options.getoption('--framework-version')
64+
65+
66+
@pytest.fixture
def tag(request, framework_version, processor, py_version):
    """Return the image tag to test.

    The ``--tag`` option wins when it was supplied; otherwise the tag is
    built as ``<framework_version>-<processor>-py<py_version>``.
    """
    explicit_tag = request.config.getoption('--tag')
    derived_tag = '{}-{}-py{}'.format(framework_version, processor, py_version)
    if explicit_tag is None:
        return derived_tag
    return explicit_tag
71+
72+
73+
@pytest.fixture(scope='session')
def sagemaker_session(region):
    """Return a SageMaker ``Session`` built on a boto3 session for ``region``."""
    boto_session = boto3.Session(region_name=region)
    return Session(boto_session=boto_session)
76+
77+
78+
@pytest.fixture(scope='session')
def sagemaker_local_session(region):
    """Return a SageMaker ``LocalSession`` built on a boto3 session for ``region``."""
    boto_session = boto3.Session(region_name=region)
    return LocalSession(boto_session=boto_session)
81+
82+
83+
@pytest.fixture(scope='session')
def account_id(request):
    """Return the value of the ``--account-id`` command-line option."""
    options = request.config
    return options.getoption('--account-id')
86+
87+
88+
@pytest.fixture
def instance_type(request, processor):
    """Return the instance type to test on.

    The ``--instance-type`` option wins when it was supplied; otherwise the
    default is ``ml.c4.xlarge`` for the ``cpu`` processor and
    ``ml.p2.xlarge`` for anything else.
    """
    requested = request.config.getoption('--instance-type')
    fallback = 'ml.c4.xlarge' if processor == 'cpu' else 'ml.p2.xlarge'
    if requested is None:
        return fallback
    return requested
93+
94+
95+
@pytest.fixture(autouse=True)
def skip_by_device_type(request, processor):
    """Skip a test whose ``skip_gpu``/``skip_cpu`` marker matches the processor.

    A ``skip_gpu`` marker skips the test when ``processor`` is ``'gpu'``; a
    ``skip_cpu`` marker skips it for every other processor value.
    """
    on_gpu = processor == 'gpu'
    node = request.node
    skip_for_gpu = node.get_closest_marker('skip_gpu') and on_gpu
    if skip_for_gpu or (node.get_closest_marker('skip_cpu') and not on_gpu):
        pytest.skip('Skipping because running on \'{}\' instance'.format(processor))
101+
102+
103+
@pytest.fixture(autouse=True)
def skip_gpu_instance_restricted_regions(region, instance_type):
    """Skip when ``region`` is on the restricted list for the instance family.

    ``ml.p2*`` instance types are checked against ``NO_P2_REGIONS`` and
    ``ml.p3*`` instance types against ``NO_P3_REGIONS``.
    """
    restrictions = (
        (NO_P2_REGIONS, 'ml.p2'),
        (NO_P3_REGIONS, 'ml.p3'),
    )
    restricted = any(
        region in blocked_regions and instance_type.startswith(family)
        for blocked_regions, family in restrictions)
    if restricted:
        pytest.skip('Skipping GPU test in region {}'.format(region))
108+
109+
110+
@pytest.fixture
def docker_image(docker_base_name, tag):
    """Return the image reference ``<docker_base_name>:<tag>``."""
    image_template = '{}:{}'
    return image_template.format(docker_base_name, tag)
113+
114+
115+
@pytest.fixture
def ecr_image(account_id, docker_base_name, tag, region):
    """Return the ECR image URI built from account, region, repository and tag."""
    uri_template = '{}.dkr.ecr.{}.amazonaws.com/{}:{}'
    return uri_template.format(account_id, region, docker_base_name, tag)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import json
16+
import os
17+
import tarfile
18+
19+
import pytest
20+
from sagemaker.tensorflow import TensorFlow
21+
22+
from test.integration.utils import processor, py_version # noqa: F401
23+
24+
# Test resources: the ``resources`` directory two levels above this file's
# directory.
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
25+
26+
27+
@pytest.mark.skip_gpu
@pytest.mark.parametrize('instances, processes', [
    (1, 2),
    (2, 1),
    (2, 2),
    (5, 2)])
def test_distributed_training_horovod_basic(instances,
                                            processes,
                                            sagemaker_local_session,
                                            docker_image,
                                            tmpdir,
                                            framework_version):
    """Run the basic Horovod script in local mode and check each rank's report.

    The job runs on ``instances`` local containers with MPI enabled and
    ``processes`` processes per host.  After training, ``model.tar.gz`` is
    extracted into ``tmpdir`` and is expected to contain one JSON file per
    rank, named ``local-rank-<local_rank>-rank-<rank>``, recording that
    rank's local rank, global rank and the total size.
    """
    output_path = 'file://%s' % tmpdir
    entry_point = os.path.join(RESOURCE_PATH, 'hvdbasic', 'train_hvd_basic.py')
    mpi_hyperparameters = {
        'sagemaker_mpi_enabled': True,
        'sagemaker_network_interface_name': 'eth0',
        'sagemaker_mpi_num_of_processes_per_host': processes,
    }

    estimator = TensorFlow(
        entry_point=entry_point,
        role='SageMakerRole',
        train_instance_type='local',
        sagemaker_session=sagemaker_local_session,
        train_instance_count=instances,
        image_name=docker_image,
        output_path=output_path,
        framework_version=framework_version,
        hyperparameters=mpi_hyperparameters)

    training_data = os.path.join(RESOURCE_PATH, 'mnist', 'data-distributed')
    estimator.fit('file://{}'.format(training_data))

    # The archive is unpacked into the same temporary directory it was
    # written to.
    extracted_dir = str(tmpdir)
    extract_files(output_path.replace('file://', ''), extracted_dir)

    world_size = instances * processes

    for rank in range(world_size):
        # Ranks are numbered consecutively, so the local rank cycles
        # through 0..processes-1.
        local_rank = rank % processes
        report_name = 'local-rank-%s-rank-%s' % (local_rank, rank)
        expected_report = {'local-rank': local_rank, 'rank': rank, 'size': world_size}
        assert read_json(report_name, extracted_dir) == expected_report
64+
65+
66+
def read_json(file, tmp):
    """Load and return the JSON document stored in ``file`` under ``tmp``.

    Args:
        file: file name, relative to ``tmp``.
        tmp: directory that contains ``file``.

    Returns:
        The decoded JSON content.
    """
    json_path = os.path.join(tmp, file)
    with open(json_path) as handle:
        content = json.load(handle)
    return content
69+
70+
71+
def assert_files_exist_in_tar(output_path, files):
    """Check that ``model.tar.gz`` under ``output_path`` contains every name in ``files``.

    Args:
        output_path: directory holding ``model.tar.gz``; a leading
            ``file://`` is stripped first.
        files: iterable of member names to look up in the archive.

    Raises:
        KeyError: from ``tarfile.TarFile.getmember`` for the first name
            that is not a member of the archive.
    """
    uri_scheme = 'file://'
    if output_path.startswith(uri_scheme):
        output_path = output_path[len(uri_scheme):]
    archive_path = os.path.join(output_path, 'model.tar.gz')
    with tarfile.open(archive_path) as archive:
        for member_name in files:
            # getmember raises KeyError when the name is absent.
            archive.getmember(member_name)
78+
79+
80+
def extract_files(output_path, tmpdir):
    """Extract ``model.tar.gz`` found under ``output_path`` into ``tmpdir``.

    Args:
        output_path: directory holding ``model.tar.gz``.  A leading
            ``file://`` on a ``str`` path is stripped, matching
            ``assert_files_exist_in_tar``, so a training ``output_path``
            URI can be passed unchanged.
        tmpdir: directory the archive members are extracted into.
    """
    uri_scheme = 'file://'
    # Only plain strings are inspected, so any other path-like value reaches
    # os.path.join unchanged.
    if isinstance(output_path, str) and output_path.startswith(uri_scheme):
        output_path = output_path[len(uri_scheme):]
    with tarfile.open(os.path.join(output_path, 'model.tar.gz')) as tar:
        # NOTE(review): extractall trusts the member paths stored in the
        # archive; do not reuse this helper on untrusted archives.
        tar.extractall(tmpdir)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import logging
16+
import os
17+
18+
import numpy as np
19+
import pytest
20+
from sagemaker.tensorflow import serving, TensorFlow
21+
22+
from test.integration import RESOURCE_PATH
23+
from test.integration.utils import processor, py_version # noqa: F401
24+
25+
26+
# Configure root logging at DEBUG level when this module is imported.
logging.basicConfig(level=logging.DEBUG)
27+
28+
29+
@pytest.mark.skip(reason="Serving part fails because of version mismatch.")
def test_keras_training(sagemaker_local_session, docker_image, tmpdir, framework_version):
    """Train ``keras_inception.py`` in local mode, deploy the model and predict.

    The estimator writes its output under ``tmpdir``.  The result is deployed
    as a local serving endpoint and queried with a random array of shape
    ``(4, 4, 4, 2)`` scaled by 255; the prediction must be truthy.
    """
    entry_point = os.path.join(RESOURCE_PATH, 'keras_inception.py')
    output_path = 'file://{}'.format(tmpdir)

    estimator = TensorFlow(
        entry_point=entry_point,
        role='SageMakerRole',
        train_instance_count=1,
        train_instance_type='local',
        image_name=docker_image,
        sagemaker_session=sagemaker_local_session,
        model_dir='/opt/ml/model',
        output_path=output_path,
        framework_version=framework_version,
        py_version='py3')

    estimator.fit()

    model = serving.Model(model_data=output_path,
                          role='SageMakerRole',
                          framework_version=framework_version,
                          sagemaker_session=sagemaker_local_session)

    predictor = model.deploy(initial_instance_count=1, instance_type='local')

    try:
        assert predictor.predict(np.random.randn(4, 4, 4, 2) * 255)
    finally:
        # Tear the local endpoint down even when the prediction call or the
        # assertion fails, so a failing run does not leave it behind.
        predictor.delete_endpoint()

0 commit comments

Comments
 (0)