Skip to content

Commit c29f508

Browse files
committed
Added flex shape for opctl config
1 parent 28474dd commit c29f508

File tree

3 files changed

+183
-0
lines changed

3 files changed

+183
-0
lines changed

ads/opctl/cmds.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,8 @@ def configure() -> None:
593593
("log_id", ""),
594594
("docker_registry", ""),
595595
("conda_pack_os_prefix", "in the format oci://<bucket>@<namespace>/<path>"),
596+
("memory_in_gbs", ""),
597+
("ocpus", "")
596598
]
597599
_set_service_configurations(
598600
ADS_JOBS_CONFIG_FILE_NAME,
@@ -619,6 +621,10 @@ def configure() -> None:
619621
("num_executors", ""),
620622
("spark_version", ""),
621623
("archive_bucket", "in the format oci://<bucket>@<namespace>/<path>"),
624+
("driver_shape_memory_in_gbs", ""),
625+
("driver_shape_ocpus", ""),
626+
("executor_shape_memory_in_gbs", ""),
627+
("executor_shape_ocpus", "")
622628
]
623629
_set_service_configurations(
624630
ADS_DATAFLOW_CONFIG_FILE_NAME,
@@ -668,6 +674,8 @@ def configure() -> None:
668674
("bandwidth_mbps", ""),
669675
("replica", ""),
670676
("web_concurrency", ""),
677+
("memory_in_gbs", ""),
678+
("ocpus", "")
671679
]
672680

673681
_set_service_configurations(

ads/opctl/config/merger.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def process(self, **kwargs) -> None:
5959
# 3. fill in values from default files under ~/.ads_ops
6060
self._fill_config_with_defaults(ads_config_path)
6161

62+
self._config_flex_shape_details()
63+
6264
logger.debug(f"Config: {self.config}")
6365
return self
6466

@@ -196,3 +198,47 @@ def _get_service_config(self, oci_profile: str, ads_config_folder: str) -> Dict:
196198
f"{os.path.join(ads_config_folder, config_file)} does not exist. No config loaded."
197199
)
198200
return {}
201+
202+
def _config_flex_shape_details(self):
203+
infrastructure = self.config["infrastructure"]
204+
backend = self.config["execution"].get("backend", None)
205+
if backend == BACKEND_NAME.JOB.value or backend == BACKEND_NAME.MODEL_DEPLOYMENT.value:
206+
shape_name = infrastructure.get("shape_name", "")
207+
if shape_name.endswith(".Flex"):
208+
if (
209+
"ocpus" not in infrastructure or
210+
"memory_in_gbs" not in infrastructure
211+
):
212+
raise ValueError(
213+
"Parameters `ocpus` and `memory_in_gbs` must be provided for using flex shape. "
214+
"Call `ads opctl config` to specify."
215+
)
216+
infrastructure["shape_config_details"] = {
217+
"ocpus": infrastructure.pop("ocpus"),
218+
"memory_in_gbs": infrastructure.pop("memory_in_gbs")
219+
}
220+
elif backend == BACKEND_NAME.DATAFLOW.value:
221+
executor_shape = infrastructure.get("executor_shape", "")
222+
driver_shape = infrastructure.get("driver_shape", "")
223+
data_flow_shape_config_details = [
224+
"driver_shape_memory_in_gbs",
225+
"driver_shape_ocpus",
226+
"executor_shape_memory_in_gbs",
227+
"executor_shape_ocpus"
228+
]
229+
# executor_shape and driver_shape must be the same shape family
230+
if executor_shape.endswith(".Flex") or driver_shape.endswith(".Flex"):
231+
for parameter in data_flow_shape_config_details:
232+
if parameter not in infrastructure:
233+
raise ValueError(
234+
f"Parameters {parameter} must be provided for using flex shape. "
235+
"Call `ads opctl config` to specify."
236+
)
237+
infrastructure["driver_shape_config"] = {
238+
"ocpus": infrastructure.pop("driver_shape_ocpus"),
239+
"memory_in_gbs": infrastructure.pop("driver_shape_memory_in_gbs")
240+
}
241+
infrastructure["executor_shape_config"] = {
242+
"ocpus": infrastructure.pop("executor_shape_ocpus"),
243+
"memory_in_gbs": infrastructure.pop("executor_shape_memory_in_gbs")
244+
}

tests/unitary/with_extras/opctl/test_opctl_config.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,135 @@ def test_fill_config_with_defaults(self):
169169
},
170170
}
171171

172+
def test_config_flex_shape_details(self):
173+
config_one = {
174+
"execution": {
175+
"backend": "job",
176+
"auth": "api_key",
177+
"oci_config": DEFAULT_OCI_CONFIG_FILE,
178+
"oci_profile": "PROFILE",
179+
"conda_pack_folder": "~/condapack",
180+
"conda_pack_os_prefix": "oci://bucket@namespace/path2",
181+
},
182+
"infrastructure": {
183+
"compartment_id": "oci.compartmentid.abcd",
184+
"project_id": "oci.projectid.abcd",
185+
"shape_name": "VM.Standard.E2.4"
186+
},
187+
}
188+
189+
m = ConfigMerger(config_one)
190+
m._config_flex_shape_details()
191+
192+
assert m.config == {
193+
"execution": {
194+
"backend": "job",
195+
"auth": "api_key",
196+
"oci_config": DEFAULT_OCI_CONFIG_FILE,
197+
"oci_profile": "PROFILE",
198+
"conda_pack_folder": "~/condapack",
199+
"conda_pack_os_prefix": "oci://bucket@namespace/path2",
200+
},
201+
"infrastructure": {
202+
"compartment_id": "oci.compartmentid.abcd",
203+
"project_id": "oci.projectid.abcd",
204+
"shape_name": "VM.Standard.E2.4"
205+
},
206+
}
207+
208+
config_one["infrastructure"]["shape_name"] = "VM.Standard.E3.Flex"
209+
m = ConfigMerger(config_one)
210+
211+
with pytest.raises(
212+
ValueError,
213+
match="Parameters `ocpus` and `memory_in_gbs` must be provided for using flex shape. "
214+
"Call `ads opctl config` to specify."
215+
):
216+
m._config_flex_shape_details()
217+
218+
config_one["infrastructure"]["ocpus"] = 2
219+
config_one["infrastructure"]["memory_in_gbs"] = 24
220+
m = ConfigMerger(config_one)
221+
m._config_flex_shape_details()
222+
223+
assert m.config == {
224+
"execution": {
225+
"backend": "job",
226+
"auth": "api_key",
227+
"oci_config": DEFAULT_OCI_CONFIG_FILE,
228+
"oci_profile": "PROFILE",
229+
"conda_pack_folder": "~/condapack",
230+
"conda_pack_os_prefix": "oci://bucket@namespace/path2",
231+
},
232+
"infrastructure": {
233+
"compartment_id": "oci.compartmentid.abcd",
234+
"project_id": "oci.projectid.abcd",
235+
"shape_name": "VM.Standard.E3.Flex",
236+
"shape_config_details": {
237+
"ocpus": 2,
238+
"memory_in_gbs": 24
239+
}
240+
},
241+
}
242+
243+
config_two = {
244+
"execution": {
245+
"backend": "dataflow",
246+
"auth": "api_key",
247+
"oci_config": DEFAULT_OCI_CONFIG_FILE,
248+
"oci_profile": "PROFILE",
249+
"conda_pack_folder": "~/condapack",
250+
"conda_pack_os_prefix": "oci://bucket@namespace/path2",
251+
},
252+
"infrastructure": {
253+
"compartment_id": "oci.compartmentid.abcd",
254+
"project_id": "oci.projectid.abcd",
255+
"executor_shape": "VM.Standard.E3.Flex",
256+
"driver_shape": "VM.Standard.E3.Flex"
257+
},
258+
}
259+
260+
m = ConfigMerger(config_two)
261+
262+
with pytest.raises(
263+
ValueError,
264+
match="Parameters driver_shape_memory_in_gbs must be provided for using flex shape. "
265+
"Call `ads opctl config` to specify."
266+
):
267+
m._config_flex_shape_details()
268+
269+
270+
config_two["infrastructure"]["driver_shape_memory_in_gbs"] = 36
271+
config_two["infrastructure"]["driver_shape_ocpus"] = 4
272+
config_two["infrastructure"]["executor_shape_memory_in_gbs"] = 48
273+
config_two["infrastructure"]["executor_shape_ocpus"] = 5
274+
275+
m = ConfigMerger(config_two)
276+
m._config_flex_shape_details()
277+
assert m.config == {
278+
"execution": {
279+
"backend": "dataflow",
280+
"auth": "api_key",
281+
"oci_config": DEFAULT_OCI_CONFIG_FILE,
282+
"oci_profile": "PROFILE",
283+
"conda_pack_folder": "~/condapack",
284+
"conda_pack_os_prefix": "oci://bucket@namespace/path2",
285+
},
286+
"infrastructure": {
287+
"compartment_id": "oci.compartmentid.abcd",
288+
"project_id": "oci.projectid.abcd",
289+
"executor_shape": "VM.Standard.E3.Flex",
290+
"executor_shape_config": {
291+
"ocpus": 5,
292+
"memory_in_gbs": 48
293+
},
294+
"driver_shape": "VM.Standard.E3.Flex",
295+
"driver_shape_config": {
296+
"ocpus": 4,
297+
"memory_in_gbs": 36
298+
}
299+
},
300+
}
172301

173302
class TestConfigResolver:
174303
def test_resolve_operator_name(self):

0 commit comments

Comments
 (0)