
Commit 6b36e98

Merge branch 'feature/guardrails' of https://github.com/oracle/accelerated-data-science into add_langchain_deployment

2 parents 9925742 + 5859927

File tree

14 files changed: +655 -120 lines

ads/llm/__init__.py

Lines changed: 12 additions & 5 deletions

@@ -4,8 +4,15 @@
 # Copyright (c) 2023 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
-
-from ads.llm.langchain.plugins.llm_gen_ai import GenerativeAI
-from ads.llm.langchain.plugins.llm_md import ModelDeploymentTGI
-from ads.llm.langchain.plugins.llm_md import ModelDeploymentVLLM
-from ads.llm.langchain.plugins.embeddings import GenerativeAIEmbeddings
+try:
+    import langchain
+    from ads.llm.langchain.plugins.llm_gen_ai import GenerativeAI
+    from ads.llm.langchain.plugins.llm_md import ModelDeploymentTGI
+    from ads.llm.langchain.plugins.llm_md import ModelDeploymentVLLM
+    from ads.llm.langchain.plugins.embeddings import GenerativeAIEmbeddings
+except ImportError as ex:
+    if ex.name == "langchain":
+        raise ImportError(
+            f"{ex.msg}\nPlease install/update langchain with `pip install langchain -U`"
+        ) from ex
+    raise ex
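Note: the guard distinguishes a missing `langchain` from an ImportError raised while importing the plugins themselves by inspecting `ex.name`. A minimal standalone sketch of the same pattern (the module name `some_optional_dep` is illustrative, not part of ADS):

try:
    import some_optional_dep  # noqa: F401
except ImportError as ex:
    if ex.name == "some_optional_dep":
        # The optional dependency itself is missing: point the user at the fix.
        raise ImportError(
            f"{ex.msg}\nPlease install it with `pip install some_optional_dep`"
        ) from ex
    # A transitive import failed instead; re-raise unchanged.
    raise ex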

ads/llm/chain.py

Lines changed: 4 additions & 4 deletions

@@ -192,15 +192,15 @@ def _save_to_file(self, chain_dict, filename, overwrite=False):
             )
 
         file_ext = pathlib.Path(expanded_path).suffix.lower()
+        if file_ext not in [".yaml", ".json"]:
+            raise ValueError(
+                f"{self.__class__.__name__} can only be saved as yaml or json format."
+            )
         with open(expanded_path, "w", encoding="utf-8") as f:
             if file_ext == ".yaml":
                 yaml.safe_dump(chain_dict, f, default_flow_style=False)
             elif file_ext == ".json":
                 json.dump(chain_dict, f)
-            else:
-                raise ValueError(
-                    f"{self.__class__.__name__} can only be saved as yaml or json format."
-                )
 
     def save(self, filename: str = None, overwrite: bool = False) -> dict:
         """Serialize the sequence to a dictionary.

ads/llm/langchain/plugins/base.py

Lines changed: 9 additions & 2 deletions

@@ -6,14 +6,13 @@
 from typing import Any, Dict, List, Optional
 
 from langchain.llms.base import LLM
-from langchain.load.serializable import Serializable
 from langchain.pydantic_v1 import BaseModel, Field, root_validator
 
 from ads.common.auth import default_signer
 from ads.config import COMPARTMENT_OCID
 
 
-class BaseLLM(LLM, Serializable):
+class BaseLLM(LLM):
     """Base OCI LLM class. Contains common attributes."""
 
     auth: dict = Field(default_factory=default_signer, exclude=True)

@@ -58,6 +57,14 @@ def is_lc_serializable(cls) -> bool:
 class GenerativeAiClientModel(BaseModel):
     """Base model for generative AI embedding model and LLM."""
 
+    # This auth is the same as the auth in the BaseLLM class.
+    # However, it is needed here for the Gen AI embedding model.
+    # Do not remove this attribute.
+    auth: dict = Field(default_factory=default_signer, exclude=True)
+    """ADS auth dictionary for OCI authentication.
+    This can be generated by calling `ads.common.auth.api_keys()` or `ads.common.auth.resource_principal()`.
+    If this is not provided then the `ads.common.default_signer()` will be used."""
+
     client: Any  #: :meta private:
     """OCI GenerativeAiClient."""
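The re-declared `auth` field relies on pydantic v1 semantics: `default_factory` fills in the signer at construction time, while `exclude=True` keeps credentials out of any serialized output. A minimal sketch of that behavior (class and factory names are illustrative, using plain pydantic v1 rather than `langchain.pydantic_v1`):

from pydantic import BaseModel, Field  # pydantic v1 API


def fake_signer() -> dict:
    # Stand-in for ads.common.auth.default_signer()
    return {"signer": "resource-principal"}


class ClientModel(BaseModel):
    # Populated automatically at construction, excluded from dict()/json()
    auth: dict = Field(default_factory=fake_signer, exclude=True)
    endpoint: str = ""


model = ClientModel(endpoint="https://example.com")
print(model.auth)    # {'signer': 'resource-principal'}
print(model.dict())  # {'endpoint': 'https://example.com'}; auth is excluded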

ads/llm/langchain/plugins/llm_md.py

Lines changed: 1 addition & 1 deletion

@@ -234,7 +234,7 @@ def _construct_json_body(self, prompt, params):
         }
 
     def _process_response(self, response_json: dict):
-        return str(response_json.get("generated_text", response_json)) + "\n"
+        return str(response_json.get("generated_text", response_json))
 
 
 class ModelDeploymentVLLM(ModelDeploymentLLM):
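With the trailing `"\n"` removed, `_process_response` returns the deployment's `generated_text` verbatim, which is what the updated guardrail tests below assert against. The one-line difference in isolation:

# Behavior sketch of the change, using a canned response payload.
response_json = {"generated_text": "Tell me a joke about cats"}
old = str(response_json.get("generated_text", response_json)) + "\n"
new = str(response_json.get("generated_text", response_json))
assert old == "Tell me a joke about cats\n"
assert new == "Tell me a joke about cats"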

ads/llm/parsers.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

ads/llm/serialize.py

Lines changed: 127 additions & 3 deletions

@@ -4,27 +4,32 @@
 # Copyright (c) 2023 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
+import base64
 import json
 import os
 import tempfile
+from copy import deepcopy
 from typing import Any, Dict, List, Optional
 
 import fsspec
 import yaml
 from langchain import llms
-from langchain.llms import loading
+from langchain.chains import RetrievalQA
 from langchain.chains.loading import load_chain_from_config
+from langchain.llms import loading
+from langchain.load import dumpd
 from langchain.load.load import Reviver
 from langchain.load.serializable import Serializable
+from langchain.vectorstores import FAISS, OpenSearchVectorSearch
+from opensearchpy.client import OpenSearch
 
 from ads.common.auth import default_signer
 from ads.common.object_storage_details import ObjectStorageDetails
-from ads.llm import GenerativeAI, ModelDeploymentVLLM, ModelDeploymentTGI
+from ads.llm import GenerativeAI, ModelDeploymentTGI, ModelDeploymentVLLM
 from ads.llm.chain import GuardrailSequence
 from ads.llm.guardrails.base import CustomGuardrailBase
 from ads.llm.patch import RunnableParallel, RunnableParallelSerializer
 
-
 # This is a temp solution for supporting custom LLM in legacy load_chain
 __lc_llm_dict = llms.get_type_to_cls_dict()
 __lc_llm_dict[GenerativeAI.__name__] = lambda: GenerativeAI

@@ -39,11 +44,129 @@ def __new_type_to_cls_dict():
 llms.get_type_to_cls_dict = __new_type_to_cls_dict
 loading.get_type_to_cls_dict = __new_type_to_cls_dict
 
+
+class OpenSearchVectorDBSerializer:
+    """Serializer for the OpenSearchVectorSearch class."""
+
+    @staticmethod
+    def type():
+        return OpenSearchVectorSearch.__name__
+
+    @staticmethod
+    def load(config: dict, **kwargs):
+        config["kwargs"]["embedding_function"] = load(
+            config["kwargs"]["embedding_function"], **kwargs
+        )
+        return OpenSearchVectorSearch(
+            **config["kwargs"],
+            http_auth=(
+                os.environ.get("OCI_OPENSEARCH_USERNAME", None),
+                os.environ.get("OCI_OPENSEARCH_PASSWORD", None),
+            ),
+            # Default to "" so an unset env var does not raise AttributeError.
+            verify_certs=os.environ.get("OCI_OPENSEARCH_VERIFY_CERTS", "").lower() == "true",
+            ca_certs=os.environ.get("OCI_OPENSEARCH_CA_CERTS", None),
+        )
+
+    @staticmethod
+    def save(obj):
+        serialized = dumpd(obj)
+        serialized["type"] = "constructor"
+        serialized["_type"] = OpenSearchVectorDBSerializer.type()
+        kwargs = {}
+        for key, val in obj.__dict__.items():
+            if key == "client":
+                if isinstance(val, OpenSearch):
+                    client_info = val.transport.hosts[0]
+                    opensearch_url = (
+                        f"https://{client_info['host']}:{client_info['port']}"
+                    )
+                    kwargs.update({"opensearch_url": opensearch_url})
+                else:
+                    raise NotImplementedError("Only the OpenSearch client is supported.")
+                continue
+            kwargs[key] = dump(val)
+        serialized["kwargs"] = kwargs
+        return serialized
+
+
+class FaissSerializer:
+    """Serializer for the FAISS class."""
+
+    @staticmethod
+    def type():
+        return FAISS.__name__
+
+    @staticmethod
+    def load(config: dict, **kwargs):
+        embedding_function = load(config["embedding_function"], **kwargs)
+        decoded_pkl = base64.b64decode(json.loads(config["vectordb"]))
+        # Load the index
+        return FAISS.deserialize_from_bytes(
+            embeddings=embedding_function, serialized=decoded_pkl
+        )
+
+    @staticmethod
+    def save(obj):
+        serialized = {}
+        serialized["_type"] = FaissSerializer.type()
+        pkl = obj.serialize_to_bytes()
+        # Encode the bytes to a base64 string
+        encoded_pkl = base64.b64encode(pkl).decode("utf-8")
+        # Serialize the base64 string
+        serialized["vectordb"] = json.dumps(encoded_pkl)
+        serialized["embedding_function"] = dump(obj.__dict__["embedding_function"])
+        return serialized
+
+
+# Mapping class to vector store serialization functions
+vectordb_serialization = {
+    "OpenSearchVectorSearch": OpenSearchVectorDBSerializer,
+    "FAISS": FaissSerializer,
+}
+
+
+class RetrievalQASerializer:
+    """Serializer for the RetrievalQA class."""
+
+    @staticmethod
+    def type():
+        return "retrieval_qa"
+
+    @staticmethod
+    def load(config: dict, **kwargs):
+        config_param = deepcopy(config)
+        retriever_kwargs = config_param.pop("retriever_kwargs")
+        vectordb_serializer = vectordb_serialization[config_param["vectordb"]["class"]]
+        vectordb = vectordb_serializer.load(config_param.pop("vectordb"), **kwargs)
+        retriever = vectordb.as_retriever(**retriever_kwargs)
+        return load_chain_from_config(config=config_param, retriever=retriever)
+
+    @staticmethod
+    def save(obj):
+        serialized = obj.dict()
+        retriever_kwargs = {}
+        for key, val in obj.retriever.__dict__.items():
+            if key not in ["tags", "metadata", "vectorstore"]:
+                retriever_kwargs[key] = val
+        serialized["retriever_kwargs"] = retriever_kwargs
+        serialized["vectordb"] = {"class": obj.retriever.vectorstore.__class__.__name__}
+        # Check support before the dict lookup so an unsupported vector store
+        # raises NotImplementedError instead of KeyError.
+        if serialized["vectordb"]["class"] not in vectordb_serialization:
+            raise NotImplementedError(
+                f"VectorDBSerializer for {serialized['vectordb']['class']} is not implemented."
+            )
+        vectordb_serializer = vectordb_serialization[serialized["vectordb"]["class"]]
+        serialized["vectordb"].update(
+            vectordb_serializer.save(obj.retriever.vectorstore)
+        )
+        return serialized
+
+
 # Mapping class to custom serialization functions
 custom_serialization = {
     GuardrailSequence: GuardrailSequence.save,
     CustomGuardrailBase: CustomGuardrailBase.save,
     RunnableParallel: RunnableParallelSerializer.save,
+    RetrievalQA: RetrievalQASerializer.save,
 }
 
 # Mapping _type to custom deserialization functions
@@ -52,6 +175,7 @@ def __new_type_to_cls_dict():
     GuardrailSequence.type(): GuardrailSequence.load,
     CustomGuardrailBase.type(): CustomGuardrailBase.load,
     RunnableParallelSerializer.type(): RunnableParallelSerializer.load,
+    RetrievalQASerializer.type(): RetrievalQASerializer.load,
 }
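FaissSerializer stores the FAISS index pickle as a base64 string because raw bytes are not JSON-serializable. The encode/decode round trip that `save` and `load` rely on can be checked in isolation:

import base64
import json

# Stand-in for the bytes returned by FAISS.serialize_to_bytes()
pkl = b"\x00\x01binary-index-bytes\xff"
encoded = json.dumps(base64.b64encode(pkl).decode("utf-8"))  # save() path
decoded = base64.b64decode(json.loads(encoded))              # load() path
assert decoded == pkl  # the bytes survive the JSON round trip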

ads/llm/templates/score_chain.jinja2

Lines changed: 3 additions & 0 deletions

@@ -150,4 +150,7 @@ def predict(data, model=load_model(), input_schema_path=os.path.join(os.path.dir
     """
     features = pre_inference(data, input_schema_path)
     output = model.invoke(features)
+    # Return the output as is if the output is a dictionary
+    if isinstance(output, dict):
+        return output
     return {'output': output}
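With the early return, a chain whose `invoke()` already produces a dictionary (for example, a RunnableParallel of named outputs) is passed through unchanged, while scalar outputs keep the `{'output': ...}` envelope. The branch in isolation:

# Behavior sketch of the new early return in predict().
def wrap_output(output):
    # Return the output as-is if it is already a dictionary
    if isinstance(output, dict):
        return output
    return {"output": output}


assert wrap_output({"answer": "42"}) == {"answer": "42"}
assert wrap_output("42") == {"output": "42"}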

dev-requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 -r test-requirements.txt
--e ".[bds,data,geo,huggingface,notebook,onnx,opctl,optuna,pii,spark,tensorflow,text,torch,viz]"
+-e ".[bds,data,geo,huggingface,llm,notebook,onnx,opctl,optuna,pii,spark,tensorflow,text,torch,viz]"
 arff
 category_encoders
 dask

pyproject.toml

Lines changed: 3 additions & 0 deletions

@@ -181,6 +181,9 @@ pii = [
     "spacy-transformers==1.2.5",
     "spacy==3.6.1",
 ]
+llm = [
+    "langchain>=0.0.295"
+]
 
 [project.urls]
 "Github" = "https://github.com/oracle/accelerated-data-science"

tests/unitary/with_extras/langchain/test_guardrails.py

Lines changed: 34 additions & 5 deletions

@@ -4,7 +4,9 @@
 # Copyright (c) 2023 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
+import json
 import os
+import tempfile
 from typing import Any, List, Dict, Mapping, Optional
 from unittest import TestCase
 from langchain.callbacks.manager import CallbackManagerForLLMRun

@@ -146,8 +148,8 @@ def test_guardrail_sequence_with_template_and_toxicity(self):
         def test_fn(chain: GuardrailSequence):
             output = chain.run("cats", num_generations=5)
             self.assertIsInstance(output, GuardrailIO)
-            self.assertIsInstance(output.data, list)
-            self.assertEqual(len(output.data), 1)
+            self.assertIsInstance(output.data, str)
+            self.assertEqual(output.data, "Tell me a joke about cats")
             self.assertIsInstance(output.info, list)
             self.assertEqual(len(output.info), len(chain.steps))
 
@@ -166,10 +168,37 @@ def test_guardrail_sequence_with_filtering(self):
         def test_fn(chain: GuardrailSequence):
             output = chain.run(self.TOXIC_CONTENT)
             self.assertIsInstance(output, GuardrailIO)
-            self.assertIsInstance(output.data, list)
-            self.assertEqual(len(output.data), 1)
-            self.assertEqual(output.data[0], message)
+            self.assertIsInstance(output.data, str)
+            self.assertEqual(output.data, message)
             self.assertIsInstance(output.info, list)
             self.assertEqual(len(output.info), len(chain.steps))
 
         self.assert_before_and_after_serialization(test_fn, chain)
+
+    def test_empty_sequence(self):
+        """Tests empty sequence."""
+        seq = GuardrailSequence()
+        self.assertEqual(seq.steps, [])
+
+    def test_save_to_file(self):
+        """Tests saving to file."""
+        message = "Let's talk something else."
+        toxicity = HuggingFaceEvaluation(
+            path="toxicity",
+            load_args=self.LOAD_ARGS,
+            threshold=0.5,
+            custom_msg=message,
+        )
+        chain = GuardrailSequence.from_sequence(self.FAKE_LLM | toxicity)
+        try:
+            temp = tempfile.NamedTemporaryFile(suffix=".json", delete=False)
+            temp.close()
+            with self.assertRaises(FileExistsError):
+                serialized = chain.save(temp.name)
+            with self.assertRaises(ValueError):
+                chain.save("abc.html")
+            serialized = chain.save(temp.name, overwrite=True)
+            with open(temp.name, "r", encoding="utf-8") as f:
+                self.assertEqual(json.load(f), serialized)
+        finally:
+            os.unlink(temp.name)
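The `delete=False` plus `os.unlink` pairing in `test_save_to_file` is deliberate: the temporary file must already exist and be closed so the first `save()` call can trigger `FileExistsError`, and cleanup runs in `finally` even when an assertion fails. The pattern in isolation:

import os
import tempfile

temp = tempfile.NamedTemporaryFile(suffix=".json", delete=False)
temp.close()  # keep the path on disk; deletion is managed manually
try:
    assert os.path.exists(temp.name)  # the overwrite check can now fire
finally:
    os.unlink(temp.name)  # always clean up, even on failure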
