Skip to content

Commit 9018bde

Browse files
authored
Added warning for load balancer (#240)
2 parents a4a3d9a + 8207481 commit 9018bde

File tree

3 files changed

+61
-45
lines changed

3 files changed

+61
-45
lines changed

ads/model/deployment/model_deployment.py

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from ads.model.deployment.model_deployment_runtime import (
4141
ModelDeploymentCondaRuntime,
4242
ModelDeploymentContainerRuntime,
43+
ModelDeploymentMode,
4344
ModelDeploymentRuntime,
4445
ModelDeploymentRuntimeType,
4546
OCIModelDeploymentRuntimeType,
@@ -80,11 +81,6 @@ class ModelDeploymentLogType:
8081
ACCESS = "access"
8182

8283

83-
class ModelDeploymentMode:
84-
HTTPS = "HTTPS_ONLY"
85-
STREAM = "STREAM_ONLY"
86-
87-
8884
class LogNotConfiguredError(Exception): # pragma: no cover
8985
pass
9086

@@ -911,48 +907,59 @@ def predict(
911907
"`data` and `json_input` are both provided. You can only use one of them."
912908
)
913909

914-
if auto_serialize_data:
915-
data = data or json_input
916-
serialized_data = serializer.serialize(data=data)
917-
return send_request(
918-
data=serialized_data,
919-
endpoint=endpoint,
920-
is_json_payload=_is_json_serializable(serialized_data),
921-
header=header,
922-
)
910+
try:
911+
if auto_serialize_data:
912+
data = data or json_input
913+
serialized_data = serializer.serialize(data=data)
914+
return send_request(
915+
data=serialized_data,
916+
endpoint=endpoint,
917+
is_json_payload=_is_json_serializable(serialized_data),
918+
header=header,
919+
)
923920

924-
if json_input is not None:
925-
if not _is_json_serializable(json_input):
926-
raise ValueError(
927-
"`json_input` must be json serializable. "
928-
"Set `auto_serialize_data` to True, or serialize the provided input data first,"
929-
"or using `data` to pass binary data."
921+
if json_input is not None:
922+
if not _is_json_serializable(json_input):
923+
raise ValueError(
924+
"`json_input` must be json serializable. "
925+
"Set `auto_serialize_data` to True, or serialize the provided input data first,"
926+
"or using `data` to pass binary data."
927+
)
928+
utils.get_logger().warning(
929+
"The `json_input` argument of `predict()` will be deprecated soon. "
930+
"Please use `data` argument. "
930931
)
931-
utils.get_logger().warning(
932-
"The `json_input` argument of `predict()` will be deprecated soon. "
933-
"Please use `data` argument. "
934-
)
935-
data = json_input
932+
data = json_input
936933

937-
is_json_payload = _is_json_serializable(data)
938-
if not isinstance(data, bytes) and not is_json_payload:
939-
raise TypeError(
940-
"`data` is not bytes or json serializable. Set `auto_serialize_data` to `True` to serialize the input data."
941-
)
942-
if model_name and model_version:
943-
header["model-name"] = model_name
944-
header["model-version"] = model_version
945-
elif bool(model_version) ^ bool(model_name):
946-
raise ValueError(
947-
"`model_name` and `model_version` have to be provided together."
934+
is_json_payload = _is_json_serializable(data)
935+
if not isinstance(data, bytes) and not is_json_payload:
936+
raise TypeError(
937+
"`data` is not bytes or json serializable. Set `auto_serialize_data` to `True` to serialize the input data."
938+
)
939+
if model_name and model_version:
940+
header["model-name"] = model_name
941+
header["model-version"] = model_version
942+
elif bool(model_version) ^ bool(model_name):
943+
raise ValueError(
944+
"`model_name` and `model_version` have to be provided together."
945+
)
946+
prediction = send_request(
947+
data=data,
948+
endpoint=endpoint,
949+
is_json_payload=is_json_payload,
950+
header=header,
948951
)
949-
prediction = send_request(
950-
data=data,
951-
endpoint=endpoint,
952-
is_json_payload=is_json_payload,
953-
header=header,
954-
)
955-
return prediction
952+
return prediction
953+
except oci.exceptions.ServiceError as ex:
954+
# When bandwidth exceeds the allocated value, TooManyRequests error (429) will be raised by oci backend.
955+
if ex.status == 429:
956+
bandwidth_mbps = self.infrastructure.bandwidth_mbps or MODEL_DEPLOYMENT_BANDWIDTH_MBPS
957+
utils.get_logger().warning(
958+
f"Load balancer bandwidth exceeds the allocated {bandwidth_mbps} Mbps."
959+
"To estimate the actual bandwidth, use formula: (payload size in KB) * (estimated requests per second) * 8 / 1024."
960+
"To resolve the issue, try sizing down the payload, slowing down the request rate or increasing the allocated bandwidth."
961+
)
962+
raise
956963

957964
def activate(
958965
self,

ads/model/deployment/model_deployment_runtime.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ class OCIModelDeploymentRuntimeType:
2121
CONTAINER = "OCIR_CONTAINER"
2222

2323

24+
class ModelDeploymentMode:
25+
HTTPS = "HTTPS_ONLY"
26+
STREAM = "STREAM_ONLY"
27+
28+
2429
class ModelDeploymentRuntime(Builder):
2530
"""A class used to represent a Model Deployment Runtime.
2631
@@ -173,7 +178,7 @@ def deployment_mode(self) -> str:
173178
str
174179
The deployment mode of model deployment.
175180
"""
176-
return self.get_spec(self.CONST_DEPLOYMENT_MODE, None)
181+
return self.get_spec(self.CONST_DEPLOYMENT_MODE, ModelDeploymentMode.HTTPS)
177182

178183
def with_deployment_mode(self, deployment_mode: str) -> "ModelDeploymentRuntime":
179184
"""Sets the deployment mode of model deployment.

tests/unitary/default_setup/model_deployment/test_model_deployment.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,17 @@
1414
ModelDeployment,
1515
ModelDeploymentProperties,
1616
)
17+
from ads.model.deployment.model_deployment_infrastructure import ModelDeploymentInfrastructure
18+
from ads.model.deployment.model_deployment_runtime import ModelDeploymentCondaRuntime
1719

1820

1921
class ModelDeploymentTestCase(unittest.TestCase):
2022
MODEL_ID = "<MODEL_OCID>"
2123
with patch.object(oci_client, "OCIClientFactory"):
2224
test_model_deployment = ModelDeployment(
23-
model_deployment_id="test_model_deployment_id", properties={}
25+
model_deployment_id="test_model_deployment_id", properties={},
26+
infrastructure=ModelDeploymentInfrastructure(),
27+
runtime=ModelDeploymentCondaRuntime()
2428
)
2529

2630
@patch("requests.post")

0 commit comments

Comments
 (0)