
Commit ea3e450

resolve conflict
2 parents: 8b83df0 + 2c7a77f

File tree: 4 files changed (+60, -23 lines)


ads/llm/langchain/plugins/llm_md.py
Lines changed: 30 additions & 3 deletions

@@ -8,6 +8,7 @@
 from typing import Any, Dict, List, Optional
 
 import requests
+from oci.auth import signers
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 
 from ads.llm.langchain.plugins.base import BaseLLM
@@ -134,11 +135,22 @@ def send_request(
             header.pop("content_type", DEFAULT_CONTENT_TYPE_JSON)
             or DEFAULT_CONTENT_TYPE_JSON
         )
+        timeout = kwargs.pop("timeout", DEFAULT_TIME_OUT)
         request_kwargs = {"json": data}
         request_kwargs["headers"] = header
-        request_kwargs["auth"] = self.auth.get("signer")
-        timeout = kwargs.pop("timeout", DEFAULT_TIME_OUT)
-        response = requests.post(endpoint, timeout=timeout, **request_kwargs, **kwargs)
+        signer = self.auth.get("signer")
+
+        attempts = 0
+        while attempts < 2:
+            request_kwargs["auth"] = signer
+            response = requests.post(
+                endpoint, timeout=timeout, **request_kwargs, **kwargs
+            )
+            if response.status_code == 401 and self.is_principal_signer(signer):
+                signer.refresh_security_token()
+                attempts += 1
+                continue
+            break
 
         try:
             response.raise_for_status()
@@ -155,6 +167,21 @@ def send_request(
 
         return response_json
 
+    @staticmethod
+    def is_principal_signer(signer):
+        """Checks if the signer is instance principal or resource principal signer."""
+        if (
+            isinstance(signer, signers.InstancePrincipalsSecurityTokenSigner)
+            or isinstance(signer, signers.ResourcePrincipalsFederationSigner)
+            or isinstance(signer, signers.EphemeralResourcePrincipalSigner)
+            or isinstance(signer, signers.EphemeralResourcePrincipalV21Signer)
+            or isinstance(signer, signers.NestedResourcePrincipals)
+            or isinstance(signer, signers.OkeWorkloadIdentityResourcePrincipalSigner)
+        ):
+            return True
+        else:
+            return False
+
 
 class ModelDeploymentTGI(ModelDeploymentLLM):
     """OCI Data Science Model Deployment TGI Endpoint.

ads/llm/patch.py
Lines changed: 2 additions & 1 deletion

@@ -1,5 +1,6 @@
 from langchain.schema.runnable import RunnableParallel
-from langchain.load import dumpd, load
+from langchain.load.dump import dumpd
+from langchain.load.load import load
 
 
 class RunnableParallelSerializer:
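
For reference, a small hedged example of the dumpd/load round trip these imports support, assuming a LangChain release of this era where langchain.load.dump and langchain.load.load are available; the prompt text is arbitrary:

from langchain.load.dump import dumpd
from langchain.load.load import load
from langchain.prompts import PromptTemplate

prompt = PromptTemplate.from_template("Tell me a fact about {topic}.")
serialized = dumpd(prompt)   # plain-dict representation of the prompt
revived = load(serialized)   # reconstructs an equivalent PromptTemplate
assert revived.format(topic="OCI") == prompt.format(topic="OCI")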

ads/llm/serialize.py
Lines changed: 23 additions & 1 deletion

@@ -19,6 +19,7 @@
 from langchain.llms import loading
 from langchain.load import dumpd
 from langchain.load.load import load as lc_load
+from langchain.load.load import Reviver
 from langchain.load.serializable import Serializable
 from langchain.vectorstores import FAISS, OpenSearchVectorSearch
 from opensearchpy.client import OpenSearch
@@ -200,14 +201,32 @@ def load(
     Returns:
         Revived LangChain objects.
     """
+    # Add ADS as valid namespace
     if not valid_namespaces:
         valid_namespaces = []
     if "ads" not in valid_namespaces:
         valid_namespaces.append("ads")
 
+    reviver = Reviver(secrets_map, valid_namespaces)
+
+    def _load(obj: Any) -> Any:
+        if isinstance(obj, dict):
+            if "_type" in obj and obj["_type"] in custom_deserialization:
+                if valid_namespaces:
+                    kwargs["valid_namespaces"] = valid_namespaces
+                if secrets_map:
+                    kwargs["secret_map"] = secrets_map
+                return custom_deserialization[obj["_type"]](obj, **kwargs)
+            # Need to revive leaf nodes before reviving this node
+            loaded_obj = {k: _load(v) for k, v in obj.items()}
+            return reviver(loaded_obj)
+        if isinstance(obj, list):
+            return [_load(o) for o in obj]
+        return obj
+
     if isinstance(obj, dict) and "_type" in obj:
         obj_type = obj["_type"]
-        # Check if the object requires a custom function to load.
+        # Check if the object has custom load function.
        if obj_type in custom_deserialization:
            if valid_namespaces:
                kwargs["valid_namespaces"] = valid_namespaces
@@ -268,6 +287,9 @@ def default(obj: Any) -> Any:
     TypeError
         If the object is not LangChain serializable.
     """
+    for super_class, save_fn in custom_serialization.items():
+        if isinstance(obj, super_class):
+            return save_fn(obj)
     if isinstance(obj, Serializable) and obj.is_lc_serializable():
         return obj.to_json()
     raise TypeError(f"Serialization of {type(obj)} is not supported.")
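
The new _load helper revives leaves before parents so that custom "_type" nodes nested anywhere in the payload are handled before the generic reviver sees them. Below is a simplified, hedged standalone sketch of that traversal; the "my_faiss" type, its loader, and the identity reviver are made up for illustration, whereas the real code dispatches to langchain's Reviver and the module-level custom_deserialization map:

from typing import Any, Callable, Dict

# Hypothetical custom loader keyed by "_type", standing in for custom_deserialization.
custom_deserialization: Dict[str, Callable[[dict], Any]] = {
    "my_faiss": lambda obj: f"<FAISS index loaded from {obj['path']}>",
}


def revive(obj: Any, reviver: Callable[[dict], Any]) -> Any:
    """Depth-first revive: dispatch "_type" nodes to custom loaders, the rest to the reviver."""
    if isinstance(obj, dict):
        if obj.get("_type") in custom_deserialization:
            return custom_deserialization[obj["_type"]](obj)
        # Revive children first so the reviver only sees already-loaded leaves.
        return reviver({k: revive(v, reviver) for k, v in obj.items()})
    if isinstance(obj, list):
        return [revive(o, reviver) for o in obj]
    return obj


# Example with an identity reviver:
data = {"retriever": {"_type": "my_faiss", "path": "/tmp/index"}, "k": 4}
print(revive(data, reviver=lambda d: d))
# {'retriever': '<FAISS index loaded from /tmp/index>', 'k': 4}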

ads/llm/templates/score_chain.jinja2
Lines changed: 5 additions & 18 deletions

@@ -132,22 +132,6 @@ def pre_inference(data, input_schema_path):
     """
     return deserialize(data, input_schema_path)
 
-
-def post_inference(yhat):
-    """
-    Post-process the model results
-
-    Parameters
-    ----------
-    yhat: Data format after calling model.predict.
-
-    Returns
-    -------
-    yhat: Data format after any processing.
-
-    """
-    return yhat
-
 def predict(data, model=load_model(), input_schema_path=os.path.join(os.path.dirname(os.path.realpath(__file__)), "input_schema.json")):
     """
     Returns prediction given the model and data to predict
@@ -165,5 +149,8 @@ def predict(data, model=load_model(), input_schema_path=os.path.join(os.path.dir
 
     """
     features = pre_inference(data, input_schema_path)
-    yhat = post_inference(model.run(features))
-    return {'prediction': yhat}
+    output = model.invoke(features)
+    # Return the output as is if the output is a dictionary
+    if isinstance(output, dict):
+        return output
+    return {'output': output}
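
With this template change, predict() invokes the loaded chain directly and normalizes its result. A quick hedged illustration of the new output contract (the helper name wrap_output is made up; it simply mirrors the tail of predict() above):

def wrap_output(output):
    # Dict outputs pass through unchanged; anything else is wrapped under "output".
    if isinstance(output, dict):
        return output
    return {"output": output}


assert wrap_output({"answer": "blue"}) == {"answer": "blue"}
assert wrap_output("blue") == {"output": "blue"}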
