@@ -117,29 +117,37 @@ def _get_instance_type_parameters(): # noqa: C901
117
117
for page in paginator .paginate (LocationType = "availability-zone" ):
118
118
for instance_type in page ["InstanceTypeOfferings" ]:
119
119
# Check if instance type ends with '.xlarge'
120
- if instance_type ["InstanceType" ].endswith (".xlarge" ) and not any (
121
- instance_type [ "InstanceType" ]. startswith ( prefix ) for prefix in excluded_instance_type_prefixes
120
+ if instance_type ["InstanceType" ].endswith (".xlarge" ) and _is_current_instance_type_generation (
121
+ excluded_instance_type_prefixes , instance_type
122
122
):
123
123
xlarge_instances .append (instance_type ["InstanceType" ])
124
- if instance_type_availability_zones .get (instance_type ["InstanceType" ]):
125
- instance_type_availability_zones [instance_type ["InstanceType" ]].append (
126
- instance_type ["Location" ]
127
- )
128
- else :
129
- instance_type_availability_zones [instance_type ["InstanceType" ]] = [
130
- instance_type ["Location" ]
131
- ]
124
+ if instance_type_availability_zones .get (instance_type ["InstanceType" ]):
125
+ instance_type_availability_zones [instance_type ["InstanceType" ]].append (
126
+ instance_type ["Location" ]
127
+ )
128
+ else :
129
+ instance_type_availability_zones [instance_type ["InstanceType" ]] = [instance_type ["Location" ]]
132
130
133
131
xlarge_instances = list (set (xlarge_instances )) # Remove redundancy.
134
132
gpu_instances = []
135
133
paginator = ec2_client .get_paginator ("describe_instance_types" )
136
134
for page in paginator .paginate (InstanceTypes = xlarge_instances ):
137
135
for instance_type in page ["InstanceTypes" ]:
138
- if instance_type .get ("GpuInfo" ):
139
- if (
140
- instance_type .get ("GpuInfo" ).get ("Gpus" )
141
- and instance_type .get ("GpuInfo" ).get ("Gpus" )[0 ].get ("Manufacturer" ) == "NVIDIA"
142
- ):
136
+ if _is_nvidia_gpu_instance_type (instance_type ):
137
+ gpu_instances .append (instance_type ["InstanceType" ])
138
+
139
+ for page in paginator .paginate ():
140
+ for instance_type in page ["InstanceTypes" ]:
141
+ if (
142
+ _is_nvidia_gpu_instance_type (instance_type )
143
+ and instance_type .get ("GpuInfo" ).get ("Gpus" )[0 ].get ("Count" ) >= 4
144
+ and _is_current_instance_type_generation (excluded_instance_type_prefixes , instance_type )
145
+ ):
146
+ # Find instance types with 4 or more GPUs. Number of GPUs can change test behavior.
147
+ # For example, it takes longer for DCGM health check to diagnose multiple GPUs.
148
+ instance_size = instance_type ["InstanceType" ].split ("." )[1 ][: - len ("xlarge" )]
149
+ if instance_size and int (instance_size ) < 20 :
150
+ # Avoid using very expensive instance types
143
151
gpu_instances .append (instance_type ["InstanceType" ])
144
152
145
153
xlarge_instances .sort ()
@@ -154,7 +162,7 @@ def _get_instance_type_parameters(): # noqa: C901
154
162
)
155
163
for index in range (len (gpu_instances )):
156
164
instance_type = gpu_instances [(today_number + index ) % len (gpu_instances )]
157
- result [f"{ region_jinja } _GPU_INSTANCE_TYPE_{ index } " ] = instance_type [: - len ( ".xlarge" )]
165
+ result [f"{ region_jinja } _GPU_INSTANCE_TYPE_{ index } " ] = instance_type
158
166
availability_zones = instance_type_availability_zones [instance_type ]
159
167
result [f"{ region_jinja } _GPU_INSTANCE_TYPE_{ index } _AZ" ] = (
160
168
availability_zones [0 ] if len (availability_zones ) <= 2 else region
@@ -165,11 +173,23 @@ def _get_instance_type_parameters(): # noqa: C901
165
173
result [f"{ region_jinja } _INSTANCE_TYPE_{ index } " ] = "c5"
166
174
result [f"{ region_jinja } _INSTANCE_TYPE_{ index } _AZ" ] = region
167
175
for index in range (10 ):
168
- result [f"{ region_jinja } _GPU_INSTANCE_TYPE_{ index } " ] = "g4dn"
176
+ result [f"{ region_jinja } _GPU_INSTANCE_TYPE_{ index } " ] = "g4dn.xlarge "
169
177
result [f"{ region_jinja } _GPU_INSTANCE_TYPE_{ index } _AZ" ] = region
170
178
return result
171
179
172
180
181
+ def _is_nvidia_gpu_instance_type (instance_type ):
182
+ return (
183
+ instance_type .get ("GpuInfo" )
184
+ and instance_type .get ("GpuInfo" ).get ("Gpus" )
185
+ and instance_type .get ("GpuInfo" ).get ("Gpus" )[0 ].get ("Manufacturer" ) == "NVIDIA"
186
+ )
187
+
188
+
189
+ def _is_current_instance_type_generation (excluded_instance_type_prefixes , instance_type ):
190
+ return not any (instance_type ["InstanceType" ].startswith (prefix ) for prefix in excluded_instance_type_prefixes )
191
+
192
+
173
193
def _get_available_amis_oss (architecture , args = None , config = None ):
174
194
"""
175
195
Gets available AMIs for given architecture from input.
@@ -306,10 +326,16 @@ def _check_or_create_capacity_reservations(config_file, os_parameters, instance_
306
326
307
327
def _resolve_instance_type_and_os (instance_type , instance_type_parameters , os , os_parameters ):
308
328
if "INSTANCE_TYPE" in instance_type :
329
+ # The value of the Jinja INSTANCE_TYPE variable can contain a size or not, e.g. trn1.32xlarge vs trn1.
330
+ # When Jinja name is like INSTANCE_TYPE_0_xlarge, the value doesn't contain size
331
+ # When Jinja name is like INSTANCE_TYPE_0, the value contains size.
332
+ # In other words, the size should appear once either in name or value. The code below handles this logic.
309
333
instance_type_size = instance_type .split ("_" )[- 1 ]
310
- instance_type = (
311
- instance_type_parameters .get (instance_type [: - len (instance_type_size ) - 1 ]) + "." + instance_type_size
312
- )
334
+ instance_type_family = instance_type_parameters .get (instance_type [: - len (instance_type_size ) - 1 ])
335
+ if instance_type_family :
336
+ instance_type = instance_type_family + "." + instance_type_size
337
+ else :
338
+ instance_type = instance_type_parameters .get (instance_type )
313
339
else :
314
340
instance_type = instance_type .replace ("_" , "." )
315
341
os_platform = "Linux/UNIX"
0 commit comments