
Commit dcd8aba

Updated pr.

1 parent 95ac6ae commit dcd8aba

File tree: 2 files changed (+50 −121 lines)

ads/model/deployment/model_deployment.py

Lines changed: 50 additions & 88 deletions
@@ -8,7 +8,6 @@
 import collections
 import copy
 import datetime
-import sys
 import oci
 import warnings
 import time
@@ -72,9 +71,6 @@
 MODEL_DEPLOYMENT_INSTANCE_COUNT = 1
 MODEL_DEPLOYMENT_BANDWIDTH_MBPS = 10
 
-TIME_FRAME = 60
-MAXIMUM_PAYLOAD_SIZE = 10 * 1024 * 1024  # bytes
-
 MODEL_DEPLOYMENT_RUNTIMES = {
     ModelDeploymentRuntimeType.CONDA: ModelDeploymentCondaRuntime,
     ModelDeploymentRuntimeType.CONTAINER: ModelDeploymentContainerRuntime,
@@ -253,10 +249,6 @@ class ModelDeployment(Builder):
         CONST_TIME_CREATED: "time_created",
     }
 
-    count_start_time = 0
-    request_counter = 0
-    estimate_request_per_second = 100
-
     initialize_spec_attributes = [
         "display_name",
         "description",
@@ -915,51 +907,60 @@ def predict(
             raise AttributeError(
                 "`data` and `json_input` are both provided. You can only use one of them."
             )
-
-        if auto_serialize_data:
-            data = data or json_input
-            serialized_data = serializer.serialize(data=data)
-            self._validate_bandwidth(serialized_data)
-            return send_request(
-                data=serialized_data,
-                endpoint=endpoint,
-                is_json_payload=_is_json_serializable(serialized_data),
-                header=header,
-            )
 
-        if json_input is not None:
-            if not _is_json_serializable(json_input):
-                raise ValueError(
-                    "`json_input` must be json serializable. "
-                    "Set `auto_serialize_data` to True, or serialize the provided input data first,"
-                    "or using `data` to pass binary data."
+        try:
+            if auto_serialize_data:
+                data = data or json_input
+                serialized_data = serializer.serialize(data=data)
+                return send_request(
+                    data=serialized_data,
+                    endpoint=endpoint,
+                    is_json_payload=_is_json_serializable(serialized_data),
+                    header=header,
                 )
-            utils.get_logger().warning(
-                "The `json_input` argument of `predict()` will be deprecated soon. "
-                "Please use `data` argument. "
-            )
-            data = json_input
 
-        is_json_payload = _is_json_serializable(data)
-        if not isinstance(data, bytes) and not is_json_payload:
-            raise TypeError(
-                "`data` is not bytes or json serializable. Set `auto_serialize_data` to `True` to serialize the input data."
-            )
-        if model_name and model_version:
-            header["model-name"] = model_name
-            header["model-version"] = model_version
-        elif bool(model_version) ^ bool(model_name):
-            raise ValueError(
-                "`model_name` and `model_version` have to be provided together."
+            if json_input is not None:
+                if not _is_json_serializable(json_input):
+                    raise ValueError(
+                        "`json_input` must be json serializable. "
+                        "Set `auto_serialize_data` to True, or serialize the provided input data first,"
+                        "or using `data` to pass binary data."
+                    )
+                utils.get_logger().warning(
+                    "The `json_input` argument of `predict()` will be deprecated soon. "
+                    "Please use `data` argument. "
+                )
+                data = json_input
+
+            is_json_payload = _is_json_serializable(data)
+            if not isinstance(data, bytes) and not is_json_payload:
+                raise TypeError(
+                    "`data` is not bytes or json serializable. Set `auto_serialize_data` to `True` to serialize the input data."
+                )
+            if model_name and model_version:
+                header["model-name"] = model_name
+                header["model-version"] = model_version
+            elif bool(model_version) ^ bool(model_name):
+                raise ValueError(
+                    "`model_name` and `model_version` have to be provided together."
+                )
+            prediction = send_request(
+                data=data,
+                endpoint=endpoint,
+                is_json_payload=is_json_payload,
+                header=header,
             )
-        self._validate_bandwidth(data)
-        prediction = send_request(
-            data=data,
-            endpoint=endpoint,
-            is_json_payload=is_json_payload,
-            header=header,
-        )
-        return prediction
+            return prediction
+        except oci.exceptions.ServiceError as ex:
+            # When bandwidth exceeds the allocated value, a TooManyRequests error (429) will be raised by the OCI backend.
+            if ex.status == 429:
+                bandwidth_mbps = self.infrastructure.bandwidth_mbps or MODEL_DEPLOYMENT_BANDWIDTH_MBPS
+                utils.get_logger().warning(
+                    f"Load balancer bandwidth exceeds the allocated {bandwidth_mbps} Mbps. "
+                    "To estimate the actual bandwidth, use formula: (payload size in KB) * (estimated requests per second) * 8 / 1024. "
+                    "To resolve the issue, try sizing down the payload, slowing down the request rate or increasing the allocated bandwidth."
+                )
+            raise
 
     def activate(
         self,
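
Note: the warning added in this hunk points callers at a back-of-the-envelope bandwidth formula. A quick worked instance of that arithmetic, using purely illustrative numbers (the payload size and request rate below are assumptions, not values from this commit):

# Formula quoted in the new warning:
# (payload size in KB) * (estimated requests per second) * 8 / 1024 -> Mbps
payload_size_kb = 200          # assumed payload size, in KB
requests_per_second = 50       # assumed request rate
estimated_mbps = payload_size_kb * requests_per_second * 8 / 1024
print(estimated_mbps)          # 78.125 Mbps, far above the 10 Mbps default (MODEL_DEPLOYMENT_BANDWIDTH_MBPS)
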
@@ -1800,45 +1801,6 @@ def _extract_spec_kwargs(self, **kwargs) -> Dict:
             if attribute in kwargs:
                 spec_kwargs[attribute] = kwargs[attribute]
         return spec_kwargs
-
-    def _validate_bandwidth(self, data: Any):
-        """Validates payload size and load balancer bandwidth.
-
-        Parameters
-        ----------
-        data: Any
-            Data or JSON payload for the prediction.
-        """
-        payload_size = sys.getsizeof(data)
-        if payload_size > MAXIMUM_PAYLOAD_SIZE:
-            raise ValueError(
-                f"Payload size exceeds the maximum allowed {MAXIMUM_PAYLOAD_SIZE} bytes. Size down the payload."
-            )
-
-        time_now = int(time.time())
-        if self.count_start_time == 0:
-            self.count_start_time = time_now
-        if time_now - self.count_start_time < TIME_FRAME:
-            self.request_counter += 1
-        else:
-            self.estimate_request_per_second = (int)(self.request_counter / TIME_FRAME)
-            self.request_counter = 0
-            self.count_start_time = 0
-
-        if not self.infrastructure or not self.runtime:
-            raise ValueError("Missing parameter infrastructure or runtime. Try reruning it after parameters are fully configured.")
-
-        # load balancer bandwidth is only needed for HTTPS mode.
-        if self.runtime.deployment_mode == ModelDeploymentMode.HTTPS:
-            bandwidth_mbps = self.infrastructure.bandwidth_mbps or MODEL_DEPLOYMENT_BANDWIDTH_MBPS
-            # formula: (payload size in KB) * (estimated requests per second) * 8 / 1024
-            # 20% extra for estimation errors and sporadic peak traffic
-            payload_size_in_kb = payload_size / 1024
-            if (payload_size_in_kb * self.estimate_request_per_second * 8 * 1.2) / 1024 > bandwidth_mbps:
-                raise ValueError(
-                    f"Load balancer bandwidth exceeds the allocated {bandwidth_mbps} Mbps."
-                    "Try sizing down the payload, slowing down the request rate or increasing bandwidth."
-                )
 
     def build(self) -> "ModelDeployment":
         """Load default values from the environment for the job infrastructure."""

tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py

Lines changed: 0 additions & 33 deletions
@@ -6,12 +6,10 @@
 
 import copy
 from datetime import datetime
-import time
 import oci
 import pytest
 import unittest
 import pandas
-import sys
 from unittest.mock import MagicMock, patch
 from ads.common.oci_datascience import OCIDataScienceMixin
 from ads.common.oci_logging import ConsolidatedLog, OCILog
@@ -21,7 +19,6 @@
 from ads.model.datascience_model import DataScienceModel
 
 from ads.model.deployment.model_deployment import (
-    MAXIMUM_PAYLOAD_SIZE,
     ModelDeployment,
     ModelDeploymentLogType,
     ModelDeploymentFailedError,
@@ -1483,33 +1480,3 @@ def test_model_deployment_with_large_size_artifact(
         )
         mock_create_model_deployment.assert_called_with(create_model_deployment_details)
         mock_sync.assert_called()
-
-    @patch.object(sys, "getsizeof")
-    def test_validate_bandwidth(self, mock_get_size_of):
-        model_deployment = self.initialize_model_deployment()
-
-        mock_get_size_of.return_value = 11 * 1024 * 1024
-        with pytest.raises(
-            ValueError,
-            match=f"Payload size exceeds the maximum allowed {MAXIMUM_PAYLOAD_SIZE} bytes. Size down the payload."
-        ):
-            model_deployment._validate_bandwidth("test")
-        mock_get_size_of.assert_called()
-
-        mock_get_size_of.return_value = 9 * 1024 * 1024
-        with pytest.raises(
-            ValueError,
-            match=f"Load balancer bandwidth exceeds the allocated {model_deployment.infrastructure.bandwidth_mbps} Mbps."
-            "Try sizing down the payload, slowing down the request rate or increasing bandwidth."
-        ):
-            model_deployment._validate_bandwidth("test")
-        mock_get_size_of.assert_called()
-
-        mock_get_size_of.return_value = 5
-        model_deployment._validate_bandwidth("test")
-        mock_get_size_of.assert_called()
-
-        model_deployment.count_start_time = (int)(time.time()) - 700
-        model_deployment._validate_bandwidth("test")
-        mock_get_size_of.assert_called()
-
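
Note: a possible follow-up test for the new behaviour (not part of this commit) would mock the module-level `send_request` used by `predict()` to raise a 429 `ServiceError` and assert that it propagates. Whether this drops into the existing test class unchanged depends on how `initialize_model_deployment()` and the deployment URL are set up, so treat it as a rough sketch:

    @patch("ads.model.deployment.model_deployment.send_request")
    def test_predict_reraises_429(self, mock_send_request):
        # Sketch only: the backend's TooManyRequests error should propagate out of predict().
        model_deployment = self.initialize_model_deployment()
        mock_send_request.side_effect = oci.exceptions.ServiceError(
            status=429, code="TooManyRequests", headers={}, message="bandwidth exceeded"
        )
        with pytest.raises(oci.exceptions.ServiceError):
            model_deployment.predict(data=b"test", auto_serialize_data=False)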
