Skip to content

Commit 98e11ff

Browse files
committed
Fixed unit tests.
1 parent 01b5606 commit 98e11ff

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

ads/model/deployment/model_deployment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -920,11 +920,10 @@ def predict(
920920
"`data` and `json_input` are both provided. You can only use one of them."
921921
)
922922

923-
self._validate_bandwidth(data or json_input)
924-
925923
if auto_serialize_data:
926924
data = data or json_input
927925
serialized_data = serializer.serialize(data=data)
926+
self._validate_bandwidth(serialized_data)
928927
return send_request(
929928
data=serialized_data,
930929
endpoint=endpoint,
@@ -957,6 +956,7 @@ def predict(
957956
raise ValueError(
958957
"`model_name` and `model_version` have to be provided together."
959958
)
959+
self._validate_bandwidth(data)
960960
prediction = send_request(
961961
data=data,
962962
endpoint=endpoint,

tests/unitary/default_setup/model_deployment/test_model_deployment.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,17 @@
1414
ModelDeployment,
1515
ModelDeploymentProperties,
1616
)
17+
from ads.model.deployment.model_deployment_infrastructure import ModelDeploymentInfrastructure
18+
from ads.model.deployment.model_deployment_runtime import ModelDeploymentCondaRuntime
1719

1820

1921
class ModelDeploymentTestCase(unittest.TestCase):
2022
MODEL_ID = "<MODEL_OCID>"
2123
with patch.object(oci_client, "OCIClientFactory"):
2224
test_model_deployment = ModelDeployment(
23-
model_deployment_id="test_model_deployment_id", properties={}
25+
model_deployment_id="test_model_deployment_id", properties={},
26+
infrastructure=ModelDeploymentInfrastructure(),
27+
runtime=ModelDeploymentCondaRuntime()
2428
)
2529

2630
@patch("requests.post")

0 commit comments

Comments
 (0)