|
19 | 19 | from sagemaker.tensorflow import TensorFlow
|
20 | 20 | from sagemaker.utils import unique_name_from_base
|
21 | 21 |
|
22 |
| -RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'resources') |
| 22 | +RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources") |
23 | 23 |
|
24 | 24 |
|
25 | 25 | @pytest.mark.skip_generic
|
26 |
| -def test_distributed_training_horovod(sagemaker_session, |
27 |
| - instance_type, |
28 |
| - image_uri, |
29 |
| - tmpdir, |
30 |
| - framework_version): |
| 26 | +def test_distributed_training_horovod( |
| 27 | + sagemaker_session, instance_type, image_uri, tmpdir, framework_version |
| 28 | +): |
31 | 29 |
|
32 |
| - mpi_options = '-verbose -x orte_base_help_aggregate=0' |
| 30 | + mpi_options = "-verbose -x orte_base_help_aggregate=0" |
33 | 31 | estimator = TensorFlow(
|
34 |
| - entry_point=os.path.join(RESOURCE_PATH, 'mnist', 'horovod_mnist.py'), |
35 |
| - role='SageMakerRole', |
| 32 | + entry_point=os.path.join(RESOURCE_PATH, "mnist", "horovod_mnist.py"), |
| 33 | + role="SageMakerRole", |
36 | 34 | train_instance_type=instance_type,
|
37 | 35 | train_instance_count=2,
|
38 | 36 | image_name=image_uri,
|
39 | 37 | framework_version=framework_version,
|
40 |
| - py_version='py3', |
| 38 | + py_version="py3", |
41 | 39 | script_mode=True,
|
42 |
| - hyperparameters={'sagemaker_mpi_enabled': True, |
43 |
| - 'sagemaker_mpi_custom_mpi_options': mpi_options, |
44 |
| - 'sagemaker_mpi_num_of_processes_per_host': 1}, |
45 |
| - sagemaker_session=sagemaker_session) |
| 40 | + hyperparameters={ |
| 41 | + "sagemaker_mpi_enabled": True, |
| 42 | + "sagemaker_mpi_custom_mpi_options": mpi_options, |
| 43 | + "sagemaker_mpi_num_of_processes_per_host": 1, |
| 44 | + }, |
| 45 | + sagemaker_session=sagemaker_session, |
| 46 | + ) |
46 | 47 |
|
47 |
| - estimator.fit(job_name=unique_name_from_base('test-tf-horovod')) |
| 48 | + estimator.fit(job_name=unique_name_from_base("test-tf-horovod")) |
48 | 49 |
|
49 | 50 | model_data_source = sagemaker.local.data.get_data_source_instance(
|
50 |
| - estimator.model_data, sagemaker_session) |
| 51 | + estimator.model_data, sagemaker_session |
| 52 | + ) |
51 | 53 |
|
52 | 54 | for filename in model_data_source.get_file_list():
|
53 |
| - assert os.path.basename(filename) == 'model.tar.gz' |
| 55 | + assert os.path.basename(filename) == "model.tar.gz" |
| 56 | + |
| 57 | + |
| 58 | +@pytest.mark.skip_generic |
| 59 | +def test_distributed_training_horovod_with_env_vars( |
| 60 | + sagemaker_session, instance_type, image_uri, tmpdir, framework_version |
| 61 | +): |
| 62 | + |
| 63 | + mpi_options = "-verbose -x orte_base_help_aggregate=0" |
| 64 | + estimator = TensorFlow( |
| 65 | + entry_point=os.path.join(RESOURCE_PATH, "hvdbasic", "train_hvd_env_vars.py"), |
| 66 | + role="SageMakerRole", |
| 67 | + train_instance_type=instance_type, |
| 68 | + train_instance_count=2, |
| 69 | + image_name=image_uri, |
| 70 | + framework_version=framework_version, |
| 71 | + py_version="py3", |
| 72 | + script_mode=True, |
| 73 | + hyperparameters={ |
| 74 | + "sagemaker_mpi_enabled": True, |
| 75 | + "sagemaker_mpi_custom_mpi_options": mpi_options, |
| 76 | + "sagemaker_mpi_num_of_processes_per_host": 2, |
| 77 | + }, |
| 78 | + sagemaker_session=sagemaker_session, |
| 79 | + ) |
| 80 | + |
| 81 | + estimator.fit(job_name=unique_name_from_base("test-tf-horovod-env-vars")) |
0 commit comments