@@ -58,8 +58,8 @@ def test_efa(
58
58
59
59
_test_shm_transfer_is_enabled (scheduler_commands , remote_command_executor , partition = "efa-enabled" )
60
60
61
- if instance == "p4d.24xlarge" and os != "centos7" :
62
- _test_nccl_benchmarks (remote_command_executor , test_datadir , "openmpi" , scheduler_commands )
61
+ if instance in [ "p4d.24xlarge" , "p5.48xlarge" ] and os != "centos7" :
62
+ _test_nccl_benchmarks (remote_command_executor , test_datadir , "openmpi" , scheduler_commands , instance )
63
63
64
64
assert_no_errors_in_logs (remote_command_executor , scheduler , skip_ice = True )
65
65
@@ -102,7 +102,7 @@ def _test_shm_transfer_is_enabled(scheduler_commands, remote_command_executor, p
102
102
assert_that (result .stdout ).does_not_contain ("SHM transfer will be disabled because of ptrace protection" )
103
103
104
104
105
- def _test_nccl_benchmarks (remote_command_executor , test_datadir , mpi_module , scheduler_commands ):
105
+ def _test_nccl_benchmarks (remote_command_executor , test_datadir , mpi_module , scheduler_commands , instance ):
106
106
logging .info ("Running NCCL benchmarks" )
107
107
remote_command_executor .run_remote_script (
108
108
str (test_datadir / "nccl_benchmarks" / "init_nccl_benchmarks.sh" ), args = [mpi_module ], hide = True , timeout = 600
@@ -139,5 +139,15 @@ def _test_nccl_benchmarks(remote_command_executor, test_datadir, mpi_module, sch
139
139
"cat /shared/nccl_tests.out | grep -E '1073741824\\ s+268435456' | awk '{print $12}'"
140
140
).stdout
141
141
142
- # Expected "in-place busbw" bandwidth with 2 nodes, 8 tasks per node is about 27GB/s
143
- assert_that (float (max_bandwidth )).is_greater_than (26.0 )
142
+ instance_bandwidth_dict = {
143
+ # p4d.24xlarge - Expected "in-place busbw" bandwidth with 2 nodes, 8 tasks per node is about 27GB/s
144
+ "p4d.24xlarge" : 26.0 ,
145
+ # p5.48xlarge - Expected "in-place busbw" bandwidth with 2 nodes, 8 tasks per node is about 250GB/s
146
+ "p5.48xlarge" : 250.0 ,
147
+ }
148
+
149
+ expected_bandwidth = instance_bandwidth_dict .get (instance )
150
+ if expected_bandwidth is None :
151
+ pytest .fail (f"Instance { instance } is not valid for multiple bandwidth tests" )
152
+
153
+ assert_that (float (max_bandwidth )).is_greater_than (expected_bandwidth )
0 commit comments