Skip to content

Commit 759e6c3

Browse files
Merge branch 'main' into feature/forecasting-model-deployments
2 parents 66eb667 + f80867a commit 759e6c3

File tree

4 files changed

+167
-76
lines changed

4 files changed

+167
-76
lines changed

ads/aqua/modeldeployment/deployment.py

Lines changed: 41 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55

66
import json
7+
import re
78
import shlex
89
import threading
910
from datetime import datetime, timedelta
@@ -764,14 +765,16 @@ def _create_deployment(
764765
).deploy(wait_for_completion=False)
765766

766767
deployment_id = deployment.id
768+
767769
logger.info(
768770
f"Aqua model deployment {deployment_id} created for model {aqua_model_id}. Work request Id is {deployment.dsc_model_deployment.workflow_req_id}"
769771
)
772+
status_list = []
770773

771774
progress_thread = threading.Thread(
772775
target=self.get_deployment_status,
773776
args=(
774-
deployment_id,
777+
deployment,
775778
deployment.dsc_model_deployment.workflow_req_id,
776779
model_type,
777780
model_name,
@@ -1277,7 +1280,7 @@ def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]:
12771280

12781281
def get_deployment_status(
12791282
self,
1280-
model_deployment_id: str,
1283+
deployment: ModelDeployment,
12811284
work_request_id: str,
12821285
model_type: str,
12831286
model_name: str,
@@ -1299,37 +1302,60 @@ def get_deployment_status(
12991302
AquaDeployment
13001303
An Aqua deployment instance.
13011304
"""
1302-
ocid = get_ocid_substring(model_deployment_id, key_len=8)
1303-
telemetry_kwargs = {"ocid": ocid}
1304-
1305+
ocid = get_ocid_substring(deployment.id, key_len=8)
13051306
data_science_work_request: DataScienceWorkRequest = DataScienceWorkRequest(
13061307
work_request_id
13071308
)
1308-
13091309
try:
13101310
data_science_work_request.wait_work_request(
13111311
progress_bar_description="Creating model deployment",
13121312
max_wait_time=DEFAULT_WAIT_TIME,
13131313
poll_interval=DEFAULT_POLL_INTERVAL,
13141314
)
13151315
except Exception:
1316+
status = ""
1317+
logs = deployment.show_logs().sort_values(by="time", ascending=False)
1318+
1319+
if logs and len(logs) > 0:
1320+
status = logs.iloc[0]["message"]
1321+
1322+
status = re.sub(r"[^a-zA-Z0-9]", " ", status)
1323+
13161324
if data_science_work_request._error_message:
13171325
error_str = ""
13181326
for error in data_science_work_request._error_message:
13191327
error_str = error_str + " " + error.message
13201328

1321-
self.telemetry.record_event(
1322-
category=f"aqua/{model_type}/deployment/status",
1323-
action="FAILED",
1324-
detail=error_str,
1325-
value=model_name,
1326-
**telemetry_kwargs,
1327-
)
1329+
error_str = re.sub(r"[^a-zA-Z0-9]", " ", error_str)
1330+
telemetry_kwargs = {
1331+
"ocid": ocid,
1332+
"model_name": model_name,
1333+
"work_request_error": error_str,
1334+
"status": status,
1335+
}
1336+
1337+
self.telemetry.record_event(
1338+
category=f"aqua/{model_type}/deployment/status",
1339+
action="FAILED",
1340+
**telemetry_kwargs,
1341+
)
1342+
else:
1343+
telemetry_kwargs = {
1344+
"ocid": ocid,
1345+
"model_name": model_name,
1346+
"status": status,
1347+
}
1348+
1349+
self.telemetry.record_event(
1350+
category=f"aqua/{model_type}/deployment/status",
1351+
action="FAILED",
1352+
**telemetry_kwargs,
1353+
)
13281354

13291355
else:
1330-
self.telemetry.record_event_async(
1356+
telemetry_kwargs = {"ocid": ocid, "model_name": model_name}
1357+
self.telemetry.record_event(
13311358
category=f"aqua/{model_type}/deployment/status",
13321359
action="SUCCEEDED",
1333-
value=model_name,
13341360
**telemetry_kwargs,
13351361
)

ads/common/oci_logging.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,22 +1,21 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

43
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import datetime
87
import logging
98
import time
10-
from typing import Dict, Union, List
9+
from typing import Dict, List, Union
1110

11+
import oci.exceptions
1212
import oci.logging
1313
import oci.loggingsearch
14-
import oci.exceptions
14+
1515
from ads.common.decorator.utils import class_or_instance_method
1616
from ads.common.oci_mixin import OCIModelMixin, OCIWorkRequestMixin
1717
from ads.common.oci_resource import OCIResource, ResourceNotFoundError
1818

19-
2019
logger = logging.getLogger(__name__)
2120

2221
# Maximum number of log records to be returned by default.
@@ -862,9 +861,7 @@ def tail(
862861
time_start=time_start,
863862
log_filter=log_filter,
864863
)
865-
self._print(
866-
sorted(tail_logs, key=lambda log: log["time"])
867-
)
864+
self._print(sorted(tail_logs, key=lambda log: log["time"]))
868865

869866
def head(
870867
self,

ads/model/deployment/model_deployment.py

Lines changed: 51 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -1,22 +1,27 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

4-
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76

87
import collections
98
import copy
109
import datetime
11-
import oci
12-
import warnings
1310
import time
14-
from typing import Dict, List, Union, Any
11+
import warnings
12+
from typing import Any, Dict, List, Union
1513

14+
import oci
1615
import oci.loggingsearch
17-
from ads.common import auth as authutil
1816
import pandas as pd
19-
from ads.model.serde.model_input import JsonModelInputSERDE
17+
from oci.data_science.models import (
18+
CreateModelDeploymentDetails,
19+
LogDetails,
20+
UpdateModelDeploymentDetails,
21+
)
22+
23+
from ads.common import auth as authutil
24+
from ads.common import utils as ads_utils
2025
from ads.common.oci_logging import (
2126
LOG_INTERVAL,
2227
LOG_RECORDS_LIMIT,
@@ -30,10 +35,10 @@
3035
from ads.model.deployment.common.utils import send_request
3136
from ads.model.deployment.model_deployment_infrastructure import (
3237
DEFAULT_BANDWIDTH_MBPS,
38+
DEFAULT_MEMORY_IN_GBS,
39+
DEFAULT_OCPUS,
3340
DEFAULT_REPLICA,
3441
DEFAULT_SHAPE_NAME,
35-
DEFAULT_OCPUS,
36-
DEFAULT_MEMORY_IN_GBS,
3742
MODEL_DEPLOYMENT_INFRASTRUCTURE_TYPE,
3843
ModelDeploymentInfrastructure,
3944
)
@@ -45,18 +50,14 @@
4550
ModelDeploymentRuntimeType,
4651
OCIModelDeploymentRuntimeType,
4752
)
53+
from ads.model.serde.model_input import JsonModelInputSERDE
4854
from ads.model.service.oci_datascience_model_deployment import (
4955
OCIDataScienceModelDeployment,
5056
)
51-
from ads.common import utils as ads_utils
57+
5258
from .common import utils
5359
from .common.utils import State
5460
from .model_deployment_properties import ModelDeploymentProperties
55-
from oci.data_science.models import (
56-
LogDetails,
57-
CreateModelDeploymentDetails,
58-
UpdateModelDeploymentDetails,
59-
)
6061

6162
DEFAULT_WAIT_TIME = 1200
6263
DEFAULT_POLL_INTERVAL = 10
@@ -751,6 +752,8 @@ def watch(
751752
log_filter : str, optional
752753
Expression for filtering the logs. This will be the WHERE clause of the query.
753754
Defaults to None.
755+
status_list : List[str], optional
756+
List of model deployment statuses. Used to store the statuses collected from the logs.
754757
755758
Returns
756759
-------
@@ -964,7 +967,9 @@ def predict(
964967
except oci.exceptions.ServiceError as ex:
965968
# When bandwidth exceeds the allocated value, TooManyRequests error (429) will be raised by oci backend.
966969
if ex.status == 429:
967-
bandwidth_mbps = self.infrastructure.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS
970+
bandwidth_mbps = (
971+
self.infrastructure.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS
972+
)
968973
utils.get_logger().warning(
969974
f"Load balancer bandwidth exceeds the allocated {bandwidth_mbps} Mbps."
970975
"To estimate the actual bandwidth, use formula: (payload size in KB) * (estimated requests per second) * 8 / 1024."
@@ -1644,22 +1649,22 @@ def _build_model_deployment_configuration_details(self) -> Dict:
16441649
}
16451650

16461651
if infrastructure.subnet_id:
1647-
instance_configuration[
1648-
infrastructure.CONST_SUBNET_ID
1649-
] = infrastructure.subnet_id
1652+
instance_configuration[infrastructure.CONST_SUBNET_ID] = (
1653+
infrastructure.subnet_id
1654+
)
16501655

16511656
if infrastructure.private_endpoint_id:
16521657
if not hasattr(
16531658
oci.data_science.models.InstanceConfiguration, "private_endpoint_id"
16541659
):
16551660
# TODO: add oci version with private endpoint support.
1656-
raise EnvironmentError(
1661+
raise OSError(
16571662
"Private endpoint is not supported in the current OCI SDK installed."
16581663
)
16591664

1660-
instance_configuration[
1661-
infrastructure.CONST_PRIVATE_ENDPOINT_ID
1662-
] = infrastructure.private_endpoint_id
1665+
instance_configuration[infrastructure.CONST_PRIVATE_ENDPOINT_ID] = (
1666+
infrastructure.private_endpoint_id
1667+
)
16631668

16641669
scaling_policy = {
16651670
infrastructure.CONST_POLICY_TYPE: "FIXED_SIZE",
@@ -1704,7 +1709,7 @@ def _build_model_deployment_configuration_details(self) -> Dict:
17041709
oci.data_science.models,
17051710
"ModelDeploymentEnvironmentConfigurationDetails",
17061711
):
1707-
raise EnvironmentError(
1712+
raise OSError(
17081713
"Environment variable hasn't been supported in the current OCI SDK installed."
17091714
)
17101715

@@ -1720,9 +1725,9 @@ def _build_model_deployment_configuration_details(self) -> Dict:
17201725
and runtime.inference_server.upper()
17211726
== MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON
17221727
):
1723-
environment_variables[
1724-
"CONTAINER_TYPE"
1725-
] = MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON
1728+
environment_variables["CONTAINER_TYPE"] = (
1729+
MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON
1730+
)
17261731
runtime.set_spec(runtime.CONST_ENV, environment_variables)
17271732
environment_configuration_details = {
17281733
runtime.CONST_ENVIRONMENT_CONFIG_TYPE: runtime.environment_config_type,
@@ -1734,17 +1739,17 @@ def _build_model_deployment_configuration_details(self) -> Dict:
17341739
oci.data_science.models,
17351740
"OcirModelDeploymentEnvironmentConfigurationDetails",
17361741
):
1737-
raise EnvironmentError(
1742+
raise OSError(
17381743
"Container runtime hasn't been supported in the current OCI SDK installed."
17391744
)
17401745
environment_configuration_details["image"] = runtime.image
17411746
environment_configuration_details["imageDigest"] = runtime.image_digest
17421747
environment_configuration_details["cmd"] = runtime.cmd
17431748
environment_configuration_details["entrypoint"] = runtime.entrypoint
17441749
environment_configuration_details["serverPort"] = runtime.server_port
1745-
environment_configuration_details[
1746-
"healthCheckPort"
1747-
] = runtime.health_check_port
1750+
environment_configuration_details["healthCheckPort"] = (
1751+
runtime.health_check_port
1752+
)
17481753

17491754
model_deployment_configuration_details = {
17501755
infrastructure.CONST_DEPLOYMENT_TYPE: "SINGLE_MODEL",
@@ -1754,7 +1759,7 @@ def _build_model_deployment_configuration_details(self) -> Dict:
17541759

17551760
if runtime.deployment_mode == ModelDeploymentMode.STREAM:
17561761
if not hasattr(oci.data_science.models, "StreamConfigurationDetails"):
1757-
raise EnvironmentError(
1762+
raise OSError(
17581763
"Model deployment mode hasn't been supported in the current OCI SDK installed."
17591764
)
17601765
model_deployment_configuration_details[
@@ -1786,9 +1791,13 @@ def _build_category_log_details(self) -> Dict:
17861791

17871792
logs = {}
17881793
if (
1789-
self.infrastructure.access_log and
1790-
self.infrastructure.access_log.get(self.infrastructure.CONST_LOG_GROUP_ID, None)
1791-
and self.infrastructure.access_log.get(self.infrastructure.CONST_LOG_ID, None)
1794+
self.infrastructure.access_log
1795+
and self.infrastructure.access_log.get(
1796+
self.infrastructure.CONST_LOG_GROUP_ID, None
1797+
)
1798+
and self.infrastructure.access_log.get(
1799+
self.infrastructure.CONST_LOG_ID, None
1800+
)
17921801
):
17931802
logs[self.infrastructure.CONST_ACCESS] = {
17941803
self.infrastructure.CONST_LOG_GROUP_ID: self.infrastructure.access_log.get(
@@ -1799,9 +1808,13 @@ def _build_category_log_details(self) -> Dict:
17991808
),
18001809
}
18011810
if (
1802-
self.infrastructure.predict_log and
1803-
self.infrastructure.predict_log.get(self.infrastructure.CONST_LOG_GROUP_ID, None)
1804-
and self.infrastructure.predict_log.get(self.infrastructure.CONST_LOG_ID, None)
1811+
self.infrastructure.predict_log
1812+
and self.infrastructure.predict_log.get(
1813+
self.infrastructure.CONST_LOG_GROUP_ID, None
1814+
)
1815+
and self.infrastructure.predict_log.get(
1816+
self.infrastructure.CONST_LOG_ID, None
1817+
)
18051818
):
18061819
logs[self.infrastructure.CONST_PREDICT] = {
18071820
self.infrastructure.CONST_LOG_GROUP_ID: self.infrastructure.predict_log.get(

0 commit comments

Comments
 (0)