
Commit 9925742

Merge branch 'feature/guardrails' of https://github.com/oracle/accelerated-data-science into add_langchain_deployment

2 parents: c12832f + e31bc07

File tree: 4 files changed, +35 −23 lines changed

- ads/llm/langchain/plugins/llm_md.py
- ads/llm/patch.py
- ads/llm/serialize.py
- ads/llm/templates/score_chain.jinja2


ads/llm/langchain/plugins/llm_md.py

Lines changed: 30 additions & 3 deletions
@@ -8,6 +8,7 @@
 from typing import Any, Dict, List, Optional
 
 import requests
+from oci.auth import signers
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 
 from ads.llm.langchain.plugins.base import BaseLLM
@@ -134,11 +135,22 @@ def send_request(
             header.pop("content_type", DEFAULT_CONTENT_TYPE_JSON)
             or DEFAULT_CONTENT_TYPE_JSON
         )
+        timeout = kwargs.pop("timeout", DEFAULT_TIME_OUT)
         request_kwargs = {"json": data}
         request_kwargs["headers"] = header
-        request_kwargs["auth"] = self.auth.get("signer")
-        timeout = kwargs.pop("timeout", DEFAULT_TIME_OUT)
-        response = requests.post(endpoint, timeout=timeout, **request_kwargs, **kwargs)
+        signer = self.auth.get("signer")
+
+        attempts = 0
+        while attempts < 2:
+            request_kwargs["auth"] = signer
+            response = requests.post(
+                endpoint, timeout=timeout, **request_kwargs, **kwargs
+            )
+            if response.status_code == 401 and self.is_principal_signer(signer):
+                signer.refresh_security_token()
+                attempts += 1
+                continue
+            break
 
         try:
            response.raise_for_status()
@@ -155,6 +167,21 @@ def send_request(
 
         return response_json
 
+    @staticmethod
+    def is_principal_signer(signer):
+        """Checks if the signer is instance principal or resource principal signer."""
+        if (
+            isinstance(signer, signers.InstancePrincipalsSecurityTokenSigner)
+            or isinstance(signer, signers.ResourcePrincipalsFederationSigner)
+            or isinstance(signer, signers.EphemeralResourcePrincipalSigner)
+            or isinstance(signer, signers.EphemeralResourcePrincipalV21Signer)
+            or isinstance(signer, signers.NestedResourcePrincipals)
+            or isinstance(signer, signers.OkeWorkloadIdentityResourcePrincipalSigner)
+        ):
+            return True
+        else:
+            return False
+
 
 class ModelDeploymentTGI(ModelDeploymentLLM):
     """OCI Data Science Model Deployment TGI Endpoint.

ads/llm/patch.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 from langchain.schema.runnable import RunnableParallel
-from langchain.load import dumpd, load
+from langchain.load.dump import dumpd
+from langchain.load.load import load
 
 
 class RunnableParallelSerializer:
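
Note: dumpd and load are LangChain's serialization helpers; only their import paths change here. A hedged usage sketch follows (not part of this commit, with PromptTemplate chosen purely as an example of a serializable object):

# Hypothetical round trip through the two helpers imported above.
from langchain.load.dump import dumpd
from langchain.load.load import load
from langchain.prompts import PromptTemplate

prompt = PromptTemplate.from_template("Summarize: {text}")
serialized = dumpd(prompt)   # plain, JSON-serializable dict with lc metadata
restored = load(serialized)  # rebuilds an equivalent PromptTemplate
assert restored.format(text="hello") == prompt.format(text="hello")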

ads/llm/serialize.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 from langchain import llms
 from langchain.llms import loading
 from langchain.chains.loading import load_chain_from_config
-from langchain.load.load import Reviver, load as __lc_load
+from langchain.load.load import Reviver
 from langchain.load.serializable import Serializable
 
 from ads.common.auth import default_signer

ads/llm/templates/score_chain.jinja2

Lines changed: 2 additions & 18 deletions
@@ -132,22 +132,6 @@ def pre_inference(data, input_schema_path):
     """
     return deserialize(data, input_schema_path)
 
-
-def post_inference(yhat):
-    """
-    Post-process the model results
-
-    Parameters
-    ----------
-    yhat: Data format after calling model.predict.
-
-    Returns
-    -------
-    yhat: Data format after any processing.
-
-    """
-    return yhat
-
 def predict(data, model=load_model(), input_schema_path=os.path.join(os.path.dirname(os.path.realpath(__file__)), "input_schema.json")):
     """
     Returns prediction given the model and data to predict
@@ -165,5 +149,5 @@ def predict(data, model=load_model(), input_schema_path=os.path.join(os.path.dir
 
     """
     features = pre_inference(data, input_schema_path)
-    yhat = post_inference(model.invoke(features))
-    return {'prediction': yhat}
+    output = model.invoke(features)
+    return {'output': output}
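
Note: with post_inference removed, the generated score.py now returns the chain's invoke() result directly under an 'output' key. Below is a minimal, hypothetical sketch of that contract, using a toy chain in place of whatever load_model() would return; the names chain and predict here are illustrative, not the template's rendered code.

# Hypothetical illustration of the new predict() contract.
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnableLambda

# Toy stand-in for the deployed LangChain object.
chain = PromptTemplate.from_template("Tell me about {topic}") | RunnableLambda(
    lambda prompt_value: f"(model answer for: {prompt_value.to_string()})"
)


def predict(data, model=chain):
    # Invoke the chain and wrap the raw result, mirroring the template change.
    output = model.invoke(data)
    return {"output": output}


print(predict({"topic": "OCI Data Science"}))
# -> {'output': '(model answer for: Tell me about OCI Data Science)'}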

0 commit comments
