Skip to content

Commit a22e3df

Browse files
authored
infra: add single-instance, multi-process Horovod test for local GPU (#389)
1 parent 60d8c10 commit a22e3df

File tree

1 file changed

+26
-13
lines changed

1 file changed

+26
-13
lines changed

test/integration/local/test_horovod.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,38 @@
2222
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
2323

2424

25+
@pytest.mark.skip_cpu
26+
@pytest.mark.skip_generic
27+
def test_distributed_training_horovod_gpu(
28+
sagemaker_local_session, image_uri, tmpdir, framework_version
29+
):
30+
_test_distributed_training_horovod(
31+
1, 2, sagemaker_local_session, image_uri, tmpdir, framework_version, 'local_gpu'
32+
)
33+
34+
2535
@pytest.mark.skip_gpu
2636
@pytest.mark.skip_generic
27-
@pytest.mark.parametrize('instances, processes', [
28-
[1, 2],
29-
(2, 1),
30-
(2, 2),
31-
(5, 2)])
32-
def test_distributed_training_horovod_basic(instances,
33-
processes,
34-
sagemaker_local_session,
35-
image_uri,
36-
tmpdir,
37-
framework_version):
37+
@pytest.mark.parametrize(
38+
'instances, processes', [(1, 2), (2, 1), (2, 2), (5, 2)]
39+
)
40+
def test_distributed_training_horovod_cpu(
41+
instances, processes, sagemaker_local_session, image_uri, tmpdir, framework_version
42+
):
43+
_test_distributed_training_horovod(
44+
instances, processes, sagemaker_local_session, image_uri, tmpdir, framework_version, 'local'
45+
)
46+
47+
48+
def _test_distributed_training_horovod(
49+
instances, processes, session, image_uri, tmpdir, framework_version, instance_type
50+
):
3851
output_path = 'file://%s' % tmpdir
3952
estimator = TensorFlow(
4053
entry_point=os.path.join(RESOURCE_PATH, 'hvdbasic', 'train_hvd_basic.py'),
4154
role='SageMakerRole',
42-
train_instance_type='local',
43-
sagemaker_session=sagemaker_local_session,
55+
train_instance_type=instance_type,
56+
sagemaker_session=session,
4457
train_instance_count=instances,
4558
image_name=image_uri,
4659
output_path=output_path,

0 commit comments

Comments
 (0)