Skip to content

Commit 2f0758f

Browse files
committed
ODSC-39392: resolve the comments
1 parent 35fe0d2 commit 2f0758f

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

ads/model/deployment/model_deployment.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from ads.common import auth as authutil
1818
import pandas as pd
1919
from ads.model.serde.model_input import JsonModelInputSERDE
20-
from ads.common import auth, oci_client
2120
from ads.common.oci_logging import (
2221
LOG_INTERVAL,
2322
LOG_RECORDS_LIMIT,
@@ -63,6 +62,7 @@
6362

6463
MODEL_DEPLOYMENT_KIND = "deployment"
6564
MODEL_DEPLOYMENT_TYPE = "modelDeployment"
65+
MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON = "TRITON"
6666

6767
MODEL_DEPLOYMENT_INSTANCE_SHAPE = "VM.Standard.E4.Flex"
6868
MODEL_DEPLOYMENT_INSTANCE_OCPUS = 1
@@ -926,10 +926,7 @@ def predict(
926926
if model_name and model_version:
927927
header['model-name'] = model_name
928928
header['model-version'] = model_version
929-
elif not model_version and not model_name:
930-
931-
pass
932-
else:
929+
elif bool(model_version) ^ bool(model_name):
933930
raise ValueError("`model_name` and `model_version` have to be provided together.")
934931
prediction = send_request(
935932
data=data, endpoint=endpoint, is_json_payload=is_json_payload, header=header,
@@ -1404,9 +1401,9 @@ def _update_from_oci_model(self, oci_model_instance) -> "ModelDeployment":
14041401
infrastructure.CONST_WEB_CONCURRENCY,
14051402
runtime.env.get("WEB_CONCURRENCY", None),
14061403
)
1407-
if runtime.env.get("CONTAINER_TYPE", None) == "TRITON":
1404+
if runtime.env.get("CONTAINER_TYPE", None) == MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON:
14081405
runtime.set_spec(
1409-
runtime.CONST_INFERENCE_SERVER, "triton"
1406+
runtime.CONST_INFERENCE_SERVER, MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON.lower()
14101407
)
14111408

14121409
self.set_spec(self.CONST_INFRASTRUCTURE, infrastructure)

ads/model/deployment/model_deployment_runtime.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,9 @@ def with_inference_server(self, inference_server: str = "triton") -> "ModelDeplo
562562
563563
Example
564564
-------
565+
>>> from ads.model.deployment import ModelDeployment, ModelDeploymentContainerRuntime, ModelDeploymentInfrastructure
566+
>>> import ads
567+
>>> ads.set_auth("resource_principal")
565568
>>> infrastructure = ModelDeploymentInfrastructure()\
566569
... .with_project_id(<project_id>)\
567570
... .with_compartment_id(<comparment_id>)\
@@ -578,7 +581,7 @@ def with_inference_server(self, inference_server: str = "triton") -> "ModelDeplo
578581
... .with_model_uri(<model_id>)\
579582
... .with_env({"key":"value", "key2":"value2"})\
580583
... .with_inference_server("triton")
581-
... deployment = ModelDeployment()\
584+
>>> deployment = ModelDeployment()\
582585
... .with_display_name("Triton Example")\
583586
... .with_infrastructure(infrastructure)\
584587
... .with_runtime(runtime)

0 commit comments

Comments
 (0)