Skip to content

Commit 611b9d4

Browse files
authored
Support for Private Endpoint on Model Deployment and AQUA MD (#979)
2 parents b44e1a7 + 30b8bcf commit 611b9d4

File tree

10 files changed

+80
-2
lines changed

10 files changed

+80
-2
lines changed

ads/aqua/extension/deployment_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def post(self, *args, **kwargs):
102102
ocpus = input_data.get("ocpus")
103103
memory_in_gbs = input_data.get("memory_in_gbs")
104104
model_file = input_data.get("model_file")
105+
private_endpoint_id = input_data.get("private_endpoint_id")
105106

106107
self.finish(
107108
AquaDeploymentApp().create(
@@ -124,6 +125,7 @@ def post(self, *args, **kwargs):
124125
ocpus=ocpus,
125126
memory_in_gbs=memory_in_gbs,
126127
model_file=model_file,
128+
private_endpoint_id=private_endpoint_id,
127129
)
128130
)
129131

ads/aqua/modeldeployment/deployment.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def create(
106106
memory_in_gbs: Optional[float] = None,
107107
ocpus: Optional[float] = None,
108108
model_file: Optional[str] = None,
109+
private_endpoint_id: Optional[str] = None,
109110
) -> "AquaDeployment":
110111
"""
111112
Creates a new Aqua deployment
@@ -152,6 +153,9 @@ def create(
152153
The ocpu count for the shape selected.
153154
model_file: str
154155
The file used for model deployment.
156+
private_endpoint_id: str
157+
The private endpoint id of model deployment.
158+
155159
Returns
156160
-------
157161
AquaDeployment
@@ -345,6 +349,7 @@ def create(
345349
.with_bandwidth_mbps(bandwidth_mbps)
346350
.with_replica(instance_count)
347351
.with_web_concurrency(web_concurrency)
352+
.with_private_endpoint_id(private_endpoint_id)
348353
.with_access_log(
349354
log_group_id=log_group_id,
350355
log_id=access_log_id,

ads/aqua/modeldeployment/entities.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
from dataclasses import dataclass, field
6-
from typing import Union
6+
from typing import Union, Optional
77

88
from oci.data_science.models import (
99
ModelDeployment,
@@ -47,9 +47,10 @@ class AquaDeployment(DataClassSerializable):
4747
created_on: str = None
4848
created_by: str = None
4949
endpoint: str = None
50+
private_endpoint_id: str = None
5051
console_link: str = None
5152
lifecycle_details: str = None
52-
shape_info: field(default_factory=ShapeInfo) = None
53+
shape_info: Optional[ShapeInfo] = None
5354
tags: dict = None
5455
environment_variables: dict = None
5556

@@ -98,6 +99,7 @@ def from_oci_model_deployment(
9899
freeform_tags = oci_model_deployment.freeform_tags or UNKNOWN_DICT
99100
aqua_service_model_tag = freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None)
100101
aqua_model_name = freeform_tags.get(Tags.AQUA_MODEL_NAME_TAG, UNKNOWN)
102+
private_endpoint_id = getattr(instance_configuration, "private_endpoint_id", UNKNOWN)
101103

102104
return AquaDeployment(
103105
id=oci_model_deployment.id,
@@ -113,6 +115,7 @@ def from_oci_model_deployment(
113115
created_on=str(oci_model_deployment.time_created),
114116
created_by=oci_model_deployment.created_by,
115117
endpoint=oci_model_deployment.model_deployment_url,
118+
private_endpoint_id=private_endpoint_id,
116119
console_link=get_console_link(
117120
resource="model-deployments",
118121
ocid=oci_model_deployment.id,

ads/model/deployment/model_deployment.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,6 +1648,19 @@ def _build_model_deployment_configuration_details(self) -> Dict:
16481648
infrastructure.CONST_SUBNET_ID
16491649
] = infrastructure.subnet_id
16501650

1651+
if infrastructure.private_endpoint_id:
1652+
if not hasattr(
1653+
oci.data_science.models.InstanceConfiguration, "private_endpoint_id"
1654+
):
1655+
# TODO: add oci version with private endpoint support.
1656+
raise EnvironmentError(
1657+
"Private endpoint is not supported in the current OCI SDK installed."
1658+
)
1659+
1660+
instance_configuration[
1661+
infrastructure.CONST_PRIVATE_ENDPOINT_ID
1662+
] = infrastructure.private_endpoint_id
1663+
16511664
scaling_policy = {
16521665
infrastructure.CONST_POLICY_TYPE: "FIXED_SIZE",
16531666
infrastructure.CONST_INSTANCE_COUNT: infrastructure.replica

ads/model/deployment/model_deployment_infrastructure.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ class ModelDeploymentInfrastructure(Builder):
5757
The web concurrency of model deployment
5858
subnet_id: str
5959
The subnet id of model deployment
60+
private_endpoint_id: str
61+
The private endpoint id of model deployment
6062
6163
Methods
6264
-------
@@ -84,6 +86,8 @@ class ModelDeploymentInfrastructure(Builder):
8486
Sets the web concurrency of model deployment
8587
with_subnet_id(subnet_id)
8688
Sets the subnet id of model deployment
89+
with_private_endpoint_id(private_endpoint_id)
90+
Sets the private endpoint id of model deployment
8791
8892
Example
8993
-------
@@ -100,6 +104,7 @@ class ModelDeploymentInfrastructure(Builder):
100104
... .with_bandwidth_mbps(10)
101105
... .with_web_concurrency(10)
102106
... .with_subnet_id(<subnet_id>)
107+
... .with_private_endpoint_id(<private_endpoint_id>)
103108
... .with_access_log(
104109
... log_group_id=<log_group_id>,
105110
... log_id=<log_id>
@@ -143,6 +148,7 @@ class ModelDeploymentInfrastructure(Builder):
143148
CONST_WEB_CONCURRENCY = "webConcurrency"
144149
CONST_STREAM_CONFIG_DETAILS = "streamConfigurationDetails"
145150
CONST_SUBNET_ID = "subnetId"
151+
CONST_PRIVATE_ENDPOINT_ID = "privateEndpointId"
146152

147153
attribute_map = {
148154
CONST_PROJECT_ID: "project_id",
@@ -159,6 +165,7 @@ class ModelDeploymentInfrastructure(Builder):
159165
CONST_LOG_GROUP_ID: "log_group_id",
160166
CONST_WEB_CONCURRENCY: "web_concurrency",
161167
CONST_SUBNET_ID: "subnet_id",
168+
CONST_PRIVATE_ENDPOINT_ID: "private_endpoint_id",
162169
}
163170

164171
shape_config_details_attribute_map = {
@@ -186,6 +193,7 @@ class ModelDeploymentInfrastructure(Builder):
186193
CONST_SHAPE_NAME: f"{MODEL_CONFIG_DETAILS_PATH}.instance_configuration.instance_shape_name",
187194
CONST_SHAPE_CONFIG_DETAILS: f"{MODEL_CONFIG_DETAILS_PATH}.instance_configuration.model_deployment_instance_shape_config_details",
188195
CONST_SUBNET_ID: f"{MODEL_CONFIG_DETAILS_PATH}.instance_configuration.subnet_id",
196+
CONST_PRIVATE_ENDPOINT_ID: f"{MODEL_CONFIG_DETAILS_PATH}.instance_configuration.private_endpoint_id",
189197
CONST_REPLICA: f"{MODEL_CONFIG_DETAILS_PATH}.scaling_policy.instance_count",
190198
CONST_BANDWIDTH_MBPS: f"{MODEL_CONFIG_DETAILS_PATH}.bandwidth_mbps",
191199
CONST_ACCESS_LOG: "category_log_details.access",
@@ -613,6 +621,32 @@ def subnet_id(self) -> str:
613621
"""
614622
return self.get_spec(self.CONST_SUBNET_ID, None)
615623

624+
def with_private_endpoint_id(self, private_endpoint_id: str) -> "ModelDeploymentInfrastructure":
625+
"""Sets the private endpoint id of model deployment.
626+
627+
Parameters
628+
----------
629+
private_endpoint_id : str
630+
The private endpoint id of model deployment.
631+
632+
Returns
633+
-------
634+
ModelDeploymentInfrastructure
635+
The ModelDeploymentInfrastructure instance (self).
636+
"""
637+
return self.set_spec(self.CONST_PRIVATE_ENDPOINT_ID, private_endpoint_id)
638+
639+
@property
640+
def private_endpoint_id(self) -> str:
641+
"""The model deployment private endpoint id.
642+
643+
Returns
644+
-------
645+
str
646+
The model deployment private endpoint id.
647+
"""
648+
return self.get_spec(self.CONST_PRIVATE_ENDPOINT_ID, None)
649+
616650
def init(self, **kwargs) -> "ModelDeploymentInfrastructure":
617651
"""Initializes a starter specification for the ModelDeploymentInfrastructure.
618652

ads/model/generic_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,6 +2262,7 @@ def deploy(
22622262
description: Optional[str] = None,
22632263
deployment_instance_shape: Optional[str] = None,
22642264
deployment_instance_subnet_id: Optional[str] = None,
2265+
deployment_instance_private_endpoint_id: Optional[str] = None,
22652266
deployment_instance_count: Optional[int] = None,
22662267
deployment_bandwidth_mbps: Optional[int] = None,
22672268
deployment_log_group_id: Optional[str] = None,
@@ -2312,6 +2313,8 @@ def deploy(
23122313
The shape of the instance used for deployment.
23132314
deployment_instance_subnet_id: (str, optional). Default to None.
23142315
The subnet id of the instance used for deployment.
2316+
deployment_instance_private_endpoint_id: (str, optional). Default to None.
2317+
The private endpoint id of the instance used for deployment.
23152318
deployment_instance_count: (int, optional). Defaults to 1.
23162319
The number of instance used for deployment.
23172320
deployment_bandwidth_mbps: (int, optional). Defaults to 10.
@@ -2432,6 +2435,8 @@ def deploy(
24322435
or self.properties.deployment_image,
24332436
deployment_instance_subnet_id=existing_infrastructure.subnet_id
24342437
or self.properties.deployment_instance_subnet_id,
2438+
deployment_instance_private_endpoint_id=existing_infrastructure.private_endpoint_id
2439+
or self.properties.deployment_instance_private_endpoint_id,
24352440
).to_dict()
24362441

24372442
property_dict.update(override_properties)
@@ -2465,6 +2470,7 @@ def deploy(
24652470
.with_shape_name(self.properties.deployment_instance_shape)
24662471
.with_replica(self.properties.deployment_instance_count)
24672472
.with_subnet_id(self.properties.deployment_instance_subnet_id)
2473+
.with_private_endpoint_id(self.properties.deployment_instance_private_endpoint_id)
24682474
)
24692475

24702476
web_concurrency = (
@@ -2611,6 +2617,7 @@ def prepare_save_deploy(
26112617
deployment_description: Optional[str] = None,
26122618
deployment_instance_shape: Optional[str] = None,
26132619
deployment_instance_subnet_id: Optional[str] = None,
2620+
deployment_instance_private_endpoint_id: Optional[str] = None,
26142621
deployment_instance_count: Optional[int] = None,
26152622
deployment_bandwidth_mbps: Optional[int] = None,
26162623
deployment_log_group_id: Optional[str] = None,
@@ -2701,6 +2708,8 @@ def prepare_save_deploy(
27012708
The shape of the instance used for deployment.
27022709
deployment_instance_subnet_id: (str, optional). Default to None.
27032710
The subnet id of the instance used for deployment.
2711+
deployment_instance_private_endpoint_id: (str, optional). Default to None.
2712+
The private endpoint id of the instance used for deployment.
27042713
deployment_instance_count: (int, optional). Defaults to 1.
27052714
The number of instance used for deployment.
27062715
deployment_bandwidth_mbps: (int, optional). Defaults to 10.
@@ -2846,6 +2855,7 @@ def prepare_save_deploy(
28462855
description=deployment_description,
28472856
deployment_instance_shape=self.properties.deployment_instance_shape,
28482857
deployment_instance_subnet_id=self.properties.deployment_instance_subnet_id,
2858+
deployment_instance_private_endpoint_id=self.properties.deployment_instance_private_endpoint_id,
28492859
deployment_instance_count=self.properties.deployment_instance_count,
28502860
deployment_bandwidth_mbps=self.properties.deployment_bandwidth_mbps,
28512861
deployment_log_group_id=self.properties.deployment_log_group_id,

ads/model/model_properties.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class ModelProperties(BaseProperties):
2929
overwrite_existing_artifact: bool = None
3030
deployment_instance_shape: str = None
3131
deployment_instance_subnet_id: str = None
32+
deployment_instance_private_endpoint_id: str = None
3233
deployment_instance_count: int = None
3334
deployment_bandwidth_mbps: int = None
3435
deployment_log_group_id: str = None

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ class TestDataset:
181181
"created_on": "2024-01-01T00:00:00.000000+00:00",
182182
"created_by": "ocid1.user.oc1..<OCID>",
183183
"endpoint": MODEL_DEPLOYMENT_URL,
184+
"private_endpoint_id": "",
184185
"environment_variables": {
185186
"BASE_MODEL": "service_models/model-name/artifact",
186187
"MODEL_DEPLOY_ENABLE_STREAMING": "true",

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def test_post(self, mock_create):
129129
memory_in_gbs=None,
130130
ocpus=None,
131131
model_file=None,
132+
private_endpoint_id=None,
132133
)
133134

134135

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,6 +1514,7 @@ def test__to_yaml(self):
15141514
"description": None,
15151515
"deployment_instance_shape": None,
15161516
"deployment_instance_subnet_id": None,
1517+
"deployment_instance_private_endpoint_id": None,
15171518
"deployment_instance_count": None,
15181519
"deployment_bandwidth_mbps": None,
15191520
"deployment_log_group_id": None,
@@ -1554,6 +1555,7 @@ def test__to_yaml(self):
15541555
"deployment_description": None,
15551556
"deployment_instance_shape": None,
15561557
"deployment_instance_subnet_id": None,
1558+
"deployment_instance_private_endpoint_id": None,
15571559
"deployment_instance_count": None,
15581560
"deployment_bandwidth_mbps": None,
15591561
"deployment_log_group_id": None,
@@ -1606,6 +1608,7 @@ def test__to_yaml(self):
16061608
"description": None,
16071609
"deployment_instance_shape": None,
16081610
"deployment_instance_subnet_id": None,
1611+
"deployment_instance_private_endpoint_id": None,
16091612
"deployment_instance_count": None,
16101613
"deployment_bandwidth_mbps": None,
16111614
"deployment_log_group_id": None,
@@ -1646,6 +1649,7 @@ def test__to_yaml(self):
16461649
"deployment_description": "fake_deployment_description",
16471650
"deployment_instance_shape": "2.1",
16481651
"deployment_instance_subnet_id": "ocid1.subnet.oc1.iad.<unique_ocid>",
1652+
"deployment_instance_private_endpoint_id": "ocid1.datascienceprivateendpointint.oc1.iad.<unique_ocid>",
16491653
"deployment_instance_count": 1,
16501654
"deployment_bandwidth_mbps": 10,
16511655
"deployment_log_group_id": "ocid1.loggroup.oc1.iad.<unique_ocid>",
@@ -1704,6 +1708,7 @@ def test__to_yaml(self):
17041708
"deployment_instance_count": 1,
17051709
"deployment_bandwidth_mbps": 10,
17061710
"deployment_instance_subnet_id": "ocid1.subnet.oc1.iad.<unique_ocid>",
1711+
"deployment_instance_private_endpoint_id": "ocid1.datascienceprivateendpointint.oc1.iad.<unique_ocid>",
17071712
"deployment_log_group_id": "ocid1.loggroup.oc1.iad.<unique_ocid>",
17081713
"deployment_access_log_id": "ocid1.log.oc1.iad.<unique_ocid>",
17091714
"deployment_predict_log_id": "ocid1.log.oc1.iad.<unique_ocid>",
@@ -1746,6 +1751,7 @@ def test__to_yaml(self):
17461751
"deployment_description": "fake_deployment_description",
17471752
"deployment_instance_shape": "2.1",
17481753
"deployment_instance_subnet_id": "ocid1.subnet.oc1.iad.<unique_ocid>",
1754+
"deployment_instance_private_endpoint_id": "ocid1.datascienceprivateendpointint.oc1.iad.<unique_ocid>",
17491755
"deployment_instance_count": 1,
17501756
"deployment_bandwidth_mbps": 10,
17511757
"deployment_log_group_id": "ocid",
@@ -1810,6 +1816,7 @@ def test__to_yaml(self):
18101816
"deployment_instance_count": 1,
18111817
"deployment_bandwidth_mbps": 10,
18121818
"deployment_instance_subnet_id": "ocid1.subnet.oc1.iad.<unique_ocid>",
1819+
"deployment_instance_private_endpoint_id": "ocid1.datascienceprivateendpointint.oc1.iad.<unique_ocid>",
18131820
"deployment_log_group_id": "ocid",
18141821
"deployment_access_log_id": "ocid",
18151822
"deployment_predict_log_id": "ocid",
@@ -1882,6 +1889,7 @@ def test_prepare_save_deploy_with_default_display_name(
18821889
"description": None,
18831890
"deployment_instance_shape": None,
18841891
"deployment_instance_subnet_id": None,
1892+
"deployment_instance_private_endpoint_id": None,
18851893
"deployment_instance_count": None,
18861894
"deployment_bandwidth_mbps": None,
18871895
"deployment_memory_in_gbs": None,

0 commit comments

Comments
 (0)