Skip to content

Commit 01b5606

Browse files
committed
Added auto validation for load balancer.
1 parent 7f49d7e commit 01b5606

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

ads/model/deployment/model_deployment.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import collections
99
import copy
1010
import datetime
11+
import sys
1112
import oci
1213
import warnings
1314
import time
@@ -70,6 +71,9 @@
7071
MODEL_DEPLOYMENT_INSTANCE_COUNT = 1
7172
MODEL_DEPLOYMENT_BANDWIDTH_MBPS = 10
7273

74+
TIME_FRAME = 60
75+
MAXIMUM_PAYLOAD_SIZE = 10 * 1024 * 1024 # bytes
76+
7377
MODEL_DEPLOYMENT_RUNTIMES = {
7478
ModelDeploymentRuntimeType.CONDA: ModelDeploymentCondaRuntime,
7579
ModelDeploymentRuntimeType.CONTAINER: ModelDeploymentContainerRuntime,
@@ -253,6 +257,10 @@ class ModelDeployment(Builder):
253257
CONST_TIME_CREATED: "time_created",
254258
}
255259

260+
count_start_time = 0
261+
request_counter = 0
262+
estimate_request_per_second = 100
263+
256264
initialize_spec_attributes = [
257265
"display_name",
258266
"description",
@@ -911,6 +919,8 @@ def predict(
911919
raise AttributeError(
912920
"`data` and `json_input` are both provided. You can only use one of them."
913921
)
922+
923+
self._validate_bandwidth(data or json_input)
914924

915925
if auto_serialize_data:
916926
data = data or json_input
@@ -1794,6 +1804,45 @@ def _extract_spec_kwargs(self, **kwargs) -> Dict:
17941804
if attribute in kwargs:
17951805
spec_kwargs[attribute] = kwargs[attribute]
17961806
return spec_kwargs
1807+
1808+
def _validate_bandwidth(self, data: Any):
1809+
"""Validates payload size and load balancer bandwidth.
1810+
1811+
Parameters
1812+
----------
1813+
data: Any
1814+
Data or JSON payload for the prediction.
1815+
"""
1816+
payload_size = sys.getsizeof(data)
1817+
if payload_size > MAXIMUM_PAYLOAD_SIZE:
1818+
raise ValueError(
1819+
f"Payload size exceeds the maximum allowed {MAXIMUM_PAYLOAD_SIZE} bytes. Size down the payload."
1820+
)
1821+
1822+
time_now = int(time.time())
1823+
if self.count_start_time == 0:
1824+
self.count_start_time = time_now
1825+
if time_now - self.count_start_time < TIME_FRAME:
1826+
self.request_counter += 1
1827+
else:
1828+
self.estimate_request_per_second = (int)(self.request_counter / TIME_FRAME)
1829+
self.request_counter = 0
1830+
self.count_start_time = 0
1831+
1832+
if not self.infrastructure or not self.runtime:
1833+
raise ValueError("Missing parameter infrastructure or runtime. Try reruning it after parameters are fully configured.")
1834+
1835+
# load balancer bandwidth is only needed for HTTPS mode.
1836+
if self.runtime.deployment_mode == ModelDeploymentMode.HTTPS:
1837+
bandwidth_mbps = self.infrastructure.bandwidth_mbps or MODEL_DEPLOYMENT_BANDWIDTH_MBPS
1838+
# formula: (payload size in KB) * (estimated requests per second) * 8 / 1024
1839+
# 20% extra for estimation errors and sporadic peak traffic
1840+
payload_size_in_kb = payload_size / 1024
1841+
if (payload_size_in_kb * self.estimate_request_per_second * 8 * 1.2) / 1024 > bandwidth_mbps:
1842+
raise ValueError(
1843+
f"Load balancer bandwidth exceeds the allocated {bandwidth_mbps} Mbps."
1844+
"Try sizing down the payload, slowing down the request rate or increasing bandwidth."
1845+
)
17971846

17981847
def build(self) -> "ModelDeployment":
17991848
"""Load default values from the environment for the job infrastructure."""

tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66

77
import copy
88
from datetime import datetime
9+
import time
910
import oci
1011
import pytest
1112
import unittest
1213
import pandas
14+
import sys
1315
from unittest.mock import MagicMock, patch
1416
from ads.common.oci_datascience import OCIDataScienceMixin
1517
from ads.common.oci_logging import ConsolidatedLog, OCILog
@@ -19,6 +21,7 @@
1921
from ads.model.datascience_model import DataScienceModel
2022

2123
from ads.model.deployment.model_deployment import (
24+
MAXIMUM_PAYLOAD_SIZE,
2225
ModelDeployment,
2326
ModelDeploymentLogType,
2427
ModelDeploymentFailedError,
@@ -1480,3 +1483,33 @@ def test_model_deployment_with_large_size_artifact(
14801483
)
14811484
mock_create_model_deployment.assert_called_with(create_model_deployment_details)
14821485
mock_sync.assert_called()
1486+
1487+
@patch.object(sys, "getsizeof")
1488+
def test_validate_bandwidth(self, mock_get_size_of):
1489+
model_deployment = self.initialize_model_deployment()
1490+
1491+
mock_get_size_of.return_value = 11 * 1024 * 1024
1492+
with pytest.raises(
1493+
ValueError,
1494+
match=f"Payload size exceeds the maximum allowed {MAXIMUM_PAYLOAD_SIZE} bytes. Size down the payload."
1495+
):
1496+
model_deployment._validate_bandwidth("test")
1497+
mock_get_size_of.assert_called()
1498+
1499+
mock_get_size_of.return_value = 9 * 1024 * 1024
1500+
with pytest.raises(
1501+
ValueError,
1502+
match=f"Load balancer bandwidth exceeds the allocated {model_deployment.infrastructure.bandwidth_mbps} Mbps."
1503+
"Try sizing down the payload, slowing down the request rate or increasing bandwidth."
1504+
):
1505+
model_deployment._validate_bandwidth("test")
1506+
mock_get_size_of.assert_called()
1507+
1508+
mock_get_size_of.return_value = 5
1509+
model_deployment._validate_bandwidth("test")
1510+
mock_get_size_of.assert_called()
1511+
1512+
model_deployment.count_start_time = (int)(time.time()) - 700
1513+
model_deployment._validate_bandwidth("test")
1514+
mock_get_size_of.assert_called()
1515+

0 commit comments

Comments
 (0)