Skip to content

Commit 61cbba2

Browse files
authored
ODSC 29065/Model Deployment opctl for local dev (#134)
2 parents a46118a + 14d6e3c commit 61cbba2

File tree

15 files changed

+799
-38
lines changed

15 files changed

+799
-38
lines changed

ads/model/deployment/model_deployment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -863,9 +863,9 @@ def predict(
863863
and `json_input` required to be json serializable. If `auto_serialize_data` set
864864
to True, data will be serialized before sending to model deployment endpoint.
865865
model_name: str
866-
Defaults to None. When the `Inference_server="triton"`, the name of the model to invoke.
866+
Defaults to None. When the `inference_server="triton"`, the name of the model to invoke.
867867
model_version: str
868-
Defaults to None. When the `Inference_server="triton"`, the version of the model to invoke.
868+
Defaults to None. When the `inference_server="triton"`, the version of the model to invoke.
869869
kwargs:
870870
content_type: str
871871
Used to indicate the media type of the resource.

ads/opctl/backend/ads_model_deployment.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@
44
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import json
78
from typing import Dict
8-
from ads.common.auth import create_signer, AuthContext
9+
10+
from oci.data_science.models import ModelDeployment as OCIModelDeployment
11+
12+
import ads
13+
from ads.common.auth import AuthContext, create_signer
914
from ads.common.oci_client import OCIClientFactory
10-
from ads.opctl.backend.base import Backend
1115
from ads.model.deployment import ModelDeployment
16+
from ads.opctl.backend.base import Backend
1217

1318

1419
class ModelDeploymentBackend(Backend):
@@ -117,3 +122,16 @@ def watch(self) -> None:
117122
model_deployment.watch(
118123
log_type=log_type, interval=interval, log_filter=log_filter
119124
)
125+
126+
def predict(self) -> None:
127+
ocid = self.config["execution"].get("ocid")
128+
data = self.config["execution"].get("payload")
129+
model_name = self.config["execution"].get("model_name")
130+
model_version = self.config["execution"].get("model_version")
131+
with AuthContext(auth=self.auth_type, profile=self.profile):
132+
model_deployment = ModelDeployment.from_id(ocid)
133+
try:
134+
data = json.loads(data)
135+
except:
136+
pass
137+
print(model_deployment.predict(data=data, model_name=model_name, model_version=model_version))

ads/opctl/backend/base.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,19 @@
77
from abc import abstractmethod
88
from typing import Dict
99

10-
from ads.common.auth import get_signer
10+
from ads.common.auth import create_signer
11+
1112

1213

1314
class Backend:
1415
"""Interface for backend"""
1516

1617
def __init__(self, config: Dict) -> None:
1718
self.config = config
18-
self.oci_auth = get_signer(
19-
config["execution"].get("oci_config", None),
20-
config["execution"].get("oci_profile", None),
21-
)
19+
self.auth_type = config["execution"].get("auth")
2220
self.profile = config["execution"].get("oci_profile", None)
21+
self.oci_config = config["execution"].get("oci_config", None)
22+
2323

2424
@abstractmethod
2525
def run(self) -> Dict:
@@ -91,3 +91,13 @@ def run_diagnostics(self):
9191
"""
9292
Implement Diagnostics check appropriate for the backend
9393
"""
94+
95+
def predict(self) -> None:
96+
"""
97+
Run model predict.
98+
99+
Returns
100+
-------
101+
None
102+
"""
103+
raise NotImplementedError("`predict` has not been implemented yet.")

0 commit comments

Comments
 (0)