Skip to content

Commit 79ccd06

Browse files
authored
Added flex shape for opctl config (#256)
2 parents b380572 + dde5ab5 commit 79ccd06

File tree

9 files changed

+215
-19
lines changed

9 files changed

+215
-19
lines changed

ads/model/deployment/model_deployment.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
from ads.model.common.utils import _is_json_serializable
3030
from ads.model.deployment.common.utils import send_request
3131
from ads.model.deployment.model_deployment_infrastructure import (
32+
DEFAULT_BANDWIDTH_MBPS,
33+
DEFAULT_REPLICA,
34+
DEFAULT_SHAPE_NAME,
35+
DEFAULT_OCPUS,
36+
DEFAULT_MEMORY_IN_GBS,
3237
MODEL_DEPLOYMENT_INFRASTRUCTURE_TYPE,
3338
ModelDeploymentInfrastructure,
3439
)
@@ -64,12 +69,6 @@
6469
MODEL_DEPLOYMENT_TYPE = "modelDeployment"
6570
MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON = "TRITON"
6671

67-
MODEL_DEPLOYMENT_INSTANCE_SHAPE = "VM.Standard.E4.Flex"
68-
MODEL_DEPLOYMENT_INSTANCE_OCPUS = 1
69-
MODEL_DEPLOYMENT_INSTANCE_MEMORY_IN_GBS = 16
70-
MODEL_DEPLOYMENT_INSTANCE_COUNT = 1
71-
MODEL_DEPLOYMENT_BANDWIDTH_MBPS = 10
72-
7372
MODEL_DEPLOYMENT_RUNTIMES = {
7473
ModelDeploymentRuntimeType.CONDA: ModelDeploymentCondaRuntime,
7574
ModelDeploymentRuntimeType.CONTAINER: ModelDeploymentContainerRuntime,
@@ -1601,7 +1600,7 @@ def _build_model_deployment_configuration_details(self) -> Dict:
16011600

16021601
instance_configuration = {
16031602
infrastructure.CONST_INSTANCE_SHAPE_NAME: infrastructure.shape_name
1604-
or MODEL_DEPLOYMENT_INSTANCE_SHAPE,
1603+
or DEFAULT_SHAPE_NAME,
16051604
}
16061605

16071606
if instance_configuration[infrastructure.CONST_INSTANCE_SHAPE_NAME].endswith(
@@ -1613,14 +1612,14 @@ def _build_model_deployment_configuration_details(self) -> Dict:
16131612
infrastructure.CONST_OCPUS: infrastructure.shape_config_details.get(
16141613
"ocpus", None
16151614
)
1616-
or MODEL_DEPLOYMENT_INSTANCE_OCPUS,
1615+
or DEFAULT_OCPUS,
16171616
infrastructure.CONST_MEMORY_IN_GBS: infrastructure.shape_config_details.get(
16181617
"memory_in_gbs", None
16191618
)
16201619
or infrastructure.shape_config_details.get(
16211620
"memoryInGBs", None
16221621
)
1623-
or MODEL_DEPLOYMENT_INSTANCE_MEMORY_IN_GBS,
1622+
or DEFAULT_MEMORY_IN_GBS,
16241623
}
16251624

16261625
if infrastructure.subnet_id:
@@ -1629,7 +1628,7 @@ def _build_model_deployment_configuration_details(self) -> Dict:
16291628
scaling_policy = {
16301629
infrastructure.CONST_POLICY_TYPE: "FIXED_SIZE",
16311630
infrastructure.CONST_INSTANCE_COUNT: infrastructure.replica
1632-
or MODEL_DEPLOYMENT_INSTANCE_COUNT,
1631+
or DEFAULT_REPLICA,
16331632
}
16341633

16351634
if not runtime.model_uri:
@@ -1660,7 +1659,7 @@ def _build_model_deployment_configuration_details(self) -> Dict:
16601659

16611660
model_configuration_details = {
16621661
infrastructure.CONST_BANDWIDTH_MBPS: infrastructure.bandwidth_mbps
1663-
or MODEL_DEPLOYMENT_BANDWIDTH_MBPS,
1662+
or DEFAULT_BANDWIDTH_MBPS,
16641663
infrastructure.CONST_INSTANCE_CONFIG: instance_configuration,
16651664
runtime.CONST_MODEL_ID: model_id,
16661665
infrastructure.CONST_SCALING_POLICY: scaling_policy,

ads/model/deployment/model_deployment_infrastructure.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
DEFAULT_BANDWIDTH_MBPS = 10
2323
DEFAULT_WEB_CONCURRENCY = 10
2424
DEFAULT_REPLICA = 1
25-
DEFAULT_SHAPE_NAME = "VM.Standard.E2.4"
25+
DEFAULT_SHAPE_NAME = "VM.Standard.E4.Flex"
26+
DEFAULT_OCPUS = 1
27+
DEFAULT_MEMORY_IN_GBS = 16
2628

2729
logger = logging.getLogger(__name__)
2830

@@ -625,4 +627,8 @@ def init(self) -> "ModelDeploymentInfrastructure":
625627
.with_web_concurrency(self.web_concurrency or DEFAULT_WEB_CONCURRENCY)
626628
.with_replica(self.replica or DEFAULT_REPLICA)
627629
.with_shape_name(self.shape_name or DEFAULT_SHAPE_NAME)
630+
.with_shape_config_details(
631+
ocpus=self.shape_config_details.get(self.CONST_OCPUS, DEFAULT_OCPUS),
632+
memory_in_gbs=self.shape_config_details.get(self.CONST_MEMORY_IN_GBS, DEFAULT_MEMORY_IN_GBS)
633+
)
628634
)

ads/opctl/cmds.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,8 @@ def configure() -> None:
593593
("log_id", ""),
594594
("docker_registry", ""),
595595
("conda_pack_os_prefix", "in the format oci://<bucket>@<namespace>/<path>"),
596+
("memory_in_gbs", ""),
597+
("ocpus", "")
596598
]
597599
_set_service_configurations(
598600
ADS_JOBS_CONFIG_FILE_NAME,
@@ -619,6 +621,10 @@ def configure() -> None:
619621
("num_executors", ""),
620622
("spark_version", ""),
621623
("archive_bucket", "in the format oci://<bucket>@<namespace>/<path>"),
624+
("driver_shape_memory_in_gbs", ""),
625+
("driver_shape_ocpus", ""),
626+
("executor_shape_memory_in_gbs", ""),
627+
("executor_shape_ocpus", "")
622628
]
623629
_set_service_configurations(
624630
ADS_DATAFLOW_CONFIG_FILE_NAME,
@@ -668,6 +674,8 @@ def configure() -> None:
668674
("bandwidth_mbps", ""),
669675
("replica", ""),
670676
("web_concurrency", ""),
677+
("memory_in_gbs", ""),
678+
("ocpus", "")
671679
]
672680

673681
_set_service_configurations(

ads/opctl/config/merger.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def process(self, **kwargs) -> None:
5959
# 3. fill in values from default files under ~/.ads_ops
6060
self._fill_config_with_defaults(ads_config_path)
6161

62+
self._config_flex_shape_details()
63+
6264
logger.debug(f"Config: {self.config}")
6365
return self
6466

@@ -196,3 +198,47 @@ def _get_service_config(self, oci_profile: str, ads_config_folder: str) -> Dict:
196198
f"{os.path.join(ads_config_folder, config_file)} does not exist. No config loaded."
197199
)
198200
return {}
201+
202+
def _config_flex_shape_details(self):
203+
infrastructure = self.config["infrastructure"]
204+
backend = self.config["execution"].get("backend", None)
205+
if backend == BACKEND_NAME.JOB.value or backend == BACKEND_NAME.MODEL_DEPLOYMENT.value:
206+
shape_name = infrastructure.get("shape_name", "")
207+
if shape_name.endswith(".Flex"):
208+
if (
209+
"ocpus" not in infrastructure or
210+
"memory_in_gbs" not in infrastructure
211+
):
212+
raise ValueError(
213+
"Parameters `ocpus` and `memory_in_gbs` must be provided for using flex shape. "
214+
"Call `ads opctl config` to specify."
215+
)
216+
infrastructure["shape_config_details"] = {
217+
"ocpus": infrastructure.pop("ocpus"),
218+
"memory_in_gbs": infrastructure.pop("memory_in_gbs")
219+
}
220+
elif backend == BACKEND_NAME.DATAFLOW.value:
221+
executor_shape = infrastructure.get("executor_shape", "")
222+
driver_shape = infrastructure.get("driver_shape", "")
223+
data_flow_shape_config_details = [
224+
"driver_shape_memory_in_gbs",
225+
"driver_shape_ocpus",
226+
"executor_shape_memory_in_gbs",
227+
"executor_shape_ocpus"
228+
]
229+
# executor_shape and driver_shape must be the same shape family
230+
if executor_shape.endswith(".Flex") or driver_shape.endswith(".Flex"):
231+
for parameter in data_flow_shape_config_details:
232+
if parameter not in infrastructure:
233+
raise ValueError(
234+
f"Parameters {parameter} must be provided for using flex shape. "
235+
"Call `ads opctl config` to specify."
236+
)
237+
infrastructure["driver_shape_config"] = {
238+
"ocpus": infrastructure.pop("driver_shape_ocpus"),
239+
"memory_in_gbs": infrastructure.pop("driver_shape_memory_in_gbs")
240+
}
241+
infrastructure["executor_shape_config"] = {
242+
"ocpus": infrastructure.pop("executor_shape_ocpus"),
243+
"memory_in_gbs": infrastructure.pop("executor_shape_memory_in_gbs")
244+
}

tests/unitary/with_extras/opctl/test_files/modeldeployment_conda.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ spec:
1818
logId: ocid1.log.oc1.iad.<unique_id>
1919
projectId: ocid1.datascienceproject.oc1.<unique_id>
2020
replica: 1
21-
shapeName: VM.Standard.E2.4
21+
shapeName: VM.Standard.E4.Flex
22+
shapeConfigDetails:
23+
ocpus: 1
24+
memoryInGBs: 16
2225
webConcurrency: 10
2326
type: datascienceModelDeployment
2427
runtime:

tests/unitary/with_extras/opctl/test_files/modeldeployment_container.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ spec:
1818
logId: ocid1.log.oc1.iad.<unique_id>
1919
projectId: ocid1.datascienceproject.oc1.<unique_id>
2020
replica: 1
21-
shapeName: VM.Standard.E2.4
21+
shapeName: VM.Standard.E4.Flex
22+
shapeConfigDetails:
23+
ocpus: 1
24+
memoryInGBs: 16
2225
webConcurrency: 10
2326
type: datascienceModelDeployment
2427
runtime:

tests/unitary/with_extras/opctl/test_opctl_cmds.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,12 @@ def test_configure(self, confirm, prompt, monkeypatch):
4545
]
4646
+ ["abc"] * 8
4747
+ ["oci://bucket@namespace/path"]
48+
+ ["abc"] * 2
4849
+ ["abc"] * 4
4950
+ ["oci://bucket@namespace/path"]
50-
+ ["abc"] * 3
51+
+ ["abc"] * 7
5152
+ ["abc"] * 4
52-
+ ["abc"] * 8
53+
+ ["abc"] * 10
5354
+ ["1"]
5455
+ ["3"]
5556
)
@@ -115,11 +116,12 @@ def test_configure_in_notebook_session(self, confirm, prompt, monkeypatch):
115116
]
116117
+ ["abc"] * 8
117118
+ ["oci://bucket@namespace/path"]
119+
+ ["abc"] * 2
118120
+ ["abc"] * 4
119121
+ ["oci://bucket@namespace/path"]
120-
+ ["abc"] * 3
122+
+ ["abc"] * 7
121123
+ ["abc"] * 4
122-
+ ["abc"] * 8
124+
+ ["abc"] * 10
123125
+ ["1"]
124126
+ ["3"]
125127
)

tests/unitary/with_extras/opctl/test_opctl_config.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,135 @@ def test_fill_config_with_defaults(self):
169169
},
170170
}
171171

172+
def test_config_flex_shape_details(self):
173+
config_one = {
174+
"execution": {
175+
"backend": "job",
176+
"auth": "api_key",
177+
"oci_config": DEFAULT_OCI_CONFIG_FILE,
178+
"oci_profile": "PROFILE",
179+
"conda_pack_folder": "~/condapack",
180+
"conda_pack_os_prefix": "oci://bucket@namespace/path2",
181+
},
182+
"infrastructure": {
183+
"compartment_id": "oci.compartmentid.abcd",
184+
"project_id": "oci.projectid.abcd",
185+
"shape_name": "VM.Standard.E2.4"
186+
},
187+
}
188+
189+
m = ConfigMerger(config_one)
190+
m._config_flex_shape_details()
191+
192+
assert m.config == {
193+
"execution": {
194+
"backend": "job",
195+
"auth": "api_key",
196+
"oci_config": DEFAULT_OCI_CONFIG_FILE,
197+
"oci_profile": "PROFILE",
198+
"conda_pack_folder": "~/condapack",
199+
"conda_pack_os_prefix": "oci://bucket@namespace/path2",
200+
},
201+
"infrastructure": {
202+
"compartment_id": "oci.compartmentid.abcd",
203+
"project_id": "oci.projectid.abcd",
204+
"shape_name": "VM.Standard.E2.4"
205+
},
206+
}
207+
208+
config_one["infrastructure"]["shape_name"] = "VM.Standard.E3.Flex"
209+
m = ConfigMerger(config_one)
210+
211+
with pytest.raises(
212+
ValueError,
213+
match="Parameters `ocpus` and `memory_in_gbs` must be provided for using flex shape. "
214+
"Call `ads opctl config` to specify."
215+
):
216+
m._config_flex_shape_details()
217+
218+
config_one["infrastructure"]["ocpus"] = 2
219+
config_one["infrastructure"]["memory_in_gbs"] = 24
220+
m = ConfigMerger(config_one)
221+
m._config_flex_shape_details()
222+
223+
assert m.config == {
224+
"execution": {
225+
"backend": "job",
226+
"auth": "api_key",
227+
"oci_config": DEFAULT_OCI_CONFIG_FILE,
228+
"oci_profile": "PROFILE",
229+
"conda_pack_folder": "~/condapack",
230+
"conda_pack_os_prefix": "oci://bucket@namespace/path2",
231+
},
232+
"infrastructure": {
233+
"compartment_id": "oci.compartmentid.abcd",
234+
"project_id": "oci.projectid.abcd",
235+
"shape_name": "VM.Standard.E3.Flex",
236+
"shape_config_details": {
237+
"ocpus": 2,
238+
"memory_in_gbs": 24
239+
}
240+
},
241+
}
242+
243+
config_two = {
244+
"execution": {
245+
"backend": "dataflow",
246+
"auth": "api_key",
247+
"oci_config": DEFAULT_OCI_CONFIG_FILE,
248+
"oci_profile": "PROFILE",
249+
"conda_pack_folder": "~/condapack",
250+
"conda_pack_os_prefix": "oci://bucket@namespace/path2",
251+
},
252+
"infrastructure": {
253+
"compartment_id": "oci.compartmentid.abcd",
254+
"project_id": "oci.projectid.abcd",
255+
"executor_shape": "VM.Standard.E3.Flex",
256+
"driver_shape": "VM.Standard.E3.Flex"
257+
},
258+
}
259+
260+
m = ConfigMerger(config_two)
261+
262+
with pytest.raises(
263+
ValueError,
264+
match="Parameters driver_shape_memory_in_gbs must be provided for using flex shape. "
265+
"Call `ads opctl config` to specify."
266+
):
267+
m._config_flex_shape_details()
268+
269+
270+
config_two["infrastructure"]["driver_shape_memory_in_gbs"] = 36
271+
config_two["infrastructure"]["driver_shape_ocpus"] = 4
272+
config_two["infrastructure"]["executor_shape_memory_in_gbs"] = 48
273+
config_two["infrastructure"]["executor_shape_ocpus"] = 5
274+
275+
m = ConfigMerger(config_two)
276+
m._config_flex_shape_details()
277+
assert m.config == {
278+
"execution": {
279+
"backend": "dataflow",
280+
"auth": "api_key",
281+
"oci_config": DEFAULT_OCI_CONFIG_FILE,
282+
"oci_profile": "PROFILE",
283+
"conda_pack_folder": "~/condapack",
284+
"conda_pack_os_prefix": "oci://bucket@namespace/path2",
285+
},
286+
"infrastructure": {
287+
"compartment_id": "oci.compartmentid.abcd",
288+
"project_id": "oci.projectid.abcd",
289+
"executor_shape": "VM.Standard.E3.Flex",
290+
"executor_shape_config": {
291+
"ocpus": 5,
292+
"memory_in_gbs": 48
293+
},
294+
"driver_shape": "VM.Standard.E3.Flex",
295+
"driver_shape_config": {
296+
"ocpus": 4,
297+
"memory_in_gbs": 36
298+
}
299+
},
300+
}
172301

173302
class TestConfigResolver:
174303
def test_resolve_operator_name(self):

tests/unitary/with_extras/opctl/test_opctl_model_deployment_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def config(self):
4646
"project_id": "ocid1.datascienceproject.oc1.<unique_id>",
4747
"log_group_id": "ocid1.loggroup.oc1.iad.<unique_id>",
4848
"log_id": "ocid1.log.oc1.iad.<unique_id>",
49-
"shape_name": "VM.Standard.E2.4",
49+
"shape_name": "VM.Standard.E4.Flex",
5050
"bandwidth_mbps": 10,
5151
"replica": 1,
5252
"web_concurrency": 10,

0 commit comments

Comments
 (0)