Skip to content

Commit 0315d8b

Browse files
authored
Support OpenSearchVectorSearch and FAISS serializer for vector store (#461)
2 parents 7b36274 + 3599bc1 commit 0315d8b

File tree

2 files changed

+288
-3
lines changed

2 files changed

+288
-3
lines changed

ads/llm/serialize.py

Lines changed: 127 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,32 @@
44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import base64
78
import json
89
import os
910
import tempfile
11+
from copy import deepcopy
1012
from typing import Any, Dict, List, Optional
1113

1214
import fsspec
1315
import yaml
1416
from langchain import llms
15-
from langchain.llms import loading
17+
from langchain.chains import RetrievalQA
1618
from langchain.chains.loading import load_chain_from_config
19+
from langchain.llms import loading
20+
from langchain.load import dumpd
1721
from langchain.load.load import Reviver
1822
from langchain.load.serializable import Serializable
23+
from langchain.vectorstores import FAISS, OpenSearchVectorSearch
24+
from opensearchpy.client import OpenSearch
1925

2026
from ads.common.auth import default_signer
2127
from ads.common.object_storage_details import ObjectStorageDetails
22-
from ads.llm import GenerativeAI, ModelDeploymentVLLM, ModelDeploymentTGI
28+
from ads.llm import GenerativeAI, ModelDeploymentTGI, ModelDeploymentVLLM
2329
from ads.llm.chain import GuardrailSequence
2430
from ads.llm.guardrails.base import CustomGuardrailBase
2531
from ads.llm.patch import RunnableParallel, RunnableParallelSerializer
2632

27-
2833
# This is a temp solution for supporting custom LLM in legacy load_chain
2934
__lc_llm_dict = llms.get_type_to_cls_dict()
3035
__lc_llm_dict[GenerativeAI.__name__] = lambda: GenerativeAI
@@ -39,11 +44,129 @@ def __new_type_to_cls_dict():
3944
llms.get_type_to_cls_dict = __new_type_to_cls_dict
4045
loading.get_type_to_cls_dict = __new_type_to_cls_dict
4146

47+
48+
class OpenSearchVectorDBSerializer:
    """
    Serializer for OpenSearchVectorSearch class

    The OpenSearch client object itself is not serialized; ``save`` records
    only the URL of its first host. When the store is loaded, the connection
    settings are read from environment variables:

    * ``OCI_OPENSEARCH_USERNAME`` / ``OCI_OPENSEARCH_PASSWORD``: passed as
      ``http_auth``.
    * ``OCI_OPENSEARCH_VERIFY_CERTS``: ``"true"`` (case-insensitive) turns
      certificate verification on; any other value, or an unset variable,
      turns it off.
    * ``OCI_OPENSEARCH_CA_CERTS``: passed as ``ca_certs``.
    """

    @staticmethod
    def type():
        """Return the ``_type`` identifier written into the serialized config."""
        return OpenSearchVectorSearch.__name__

    @staticmethod
    def _verify_certs_from_env() -> bool:
        """Return True only if OCI_OPENSEARCH_VERIFY_CERTS is set to "true"."""
        # Default to "" so that an unset variable yields False instead of
        # raising AttributeError on ``None.lower()``.
        return os.environ.get("OCI_OPENSEARCH_VERIFY_CERTS", "").lower() == "true"

    @staticmethod
    def load(config: dict, **kwargs):
        """Re-create an OpenSearchVectorSearch from a config produced by ``save``.

        Parameters
        ----------
        config : dict
            Serialized vector store. ``config["kwargs"]`` holds the
            constructor arguments, including the serialized
            ``embedding_function``.
        **kwargs
            Forwarded to the module-level ``load`` used to deserialize the
            embedding function.

        Returns
        -------
        OpenSearchVectorSearch
            A vector store built from the saved arguments and the
            environment-provided connection settings.
        """
        # Work on a copy so the caller's config is left untouched and can be
        # loaded again.
        store_kwargs = dict(config["kwargs"])
        store_kwargs["embedding_function"] = load(
            store_kwargs["embedding_function"], **kwargs
        )
        return OpenSearchVectorSearch(
            **store_kwargs,
            http_auth=(
                os.environ.get("OCI_OPENSEARCH_USERNAME", None),
                os.environ.get("OCI_OPENSEARCH_PASSWORD", None),
            ),
            verify_certs=OpenSearchVectorDBSerializer._verify_certs_from_env(),
            ca_certs=os.environ.get("OCI_OPENSEARCH_CA_CERTS", None),
        )

    @staticmethod
    def save(obj):
        """Serialize an OpenSearchVectorSearch instance into a dictionary.

        Parameters
        ----------
        obj : OpenSearchVectorSearch
            The vector store to serialize.

        Returns
        -------
        dict
            The ``dumpd`` representation of ``obj`` with ``type``, ``_type``
            and ``kwargs`` overridden. ``kwargs`` contains every instance
            attribute passed through the module-level ``dump``, except
            ``client``, which is replaced by ``opensearch_url``.

        Raises
        ------
        NotImplementedError
            If ``obj.client`` is not an ``opensearchpy`` ``OpenSearch`` client.
        """
        serialized = dumpd(obj)
        serialized["type"] = "constructor"
        serialized["_type"] = OpenSearchVectorDBSerializer.type()

        kwargs = {}
        for key, val in obj.__dict__.items():
            if key != "client":
                kwargs[key] = dump(val)
                continue
            if not isinstance(val, OpenSearch):
                raise NotImplementedError("Only support OpenSearch client.")
            # Only the first configured host is recorded.
            # NOTE(review): the scheme is always written as https, even if the
            # client was created with a different one - confirm this is intended.
            client_info = val.transport.hosts[0]
            kwargs["opensearch_url"] = (
                f"https://{client_info['host']}:{client_info['port']}"
            )

        serialized["kwargs"] = kwargs
        return serialized
91+
92+
93+
class FaissSerializer:
    """
    Serializer for FAISS vector store class

    The index is stored under ``"vectordb"`` as a JSON-encoded base64 string
    of the bytes returned by ``serialize_to_bytes``; the embedding function
    is stored separately under ``"embedding_function"``.
    """

    @staticmethod
    def type():
        """Return the ``_type`` identifier written into the serialized config."""
        return FAISS.__name__

    @staticmethod
    def load(config: dict, **kwargs):
        """Rebuild a FAISS vector store from a config produced by ``save``.

        ``kwargs`` are forwarded to the module-level ``load`` that
        deserializes the embedding function.
        """
        embeddings = load(config["embedding_function"], **kwargs)
        # Undo the two encoding layers applied by ``save``: JSON, then base64.
        index_b64 = json.loads(config["vectordb"])
        index_bytes = base64.b64decode(index_b64)
        # NOTE(review): the payload looks like pickled data; if
        # deserialize_from_bytes unpickles it, only trusted configs should be
        # loaded here - confirm.
        return FAISS.deserialize_from_bytes(
            embeddings=embeddings, serialized=index_bytes
        )

    @staticmethod
    def save(obj):
        """Serialize a FAISS vector store into a JSON-compatible dictionary."""
        # Raw bytes are not JSON-serializable, so encode them as base64 text
        # and store that text as a JSON string.
        index_b64 = base64.b64encode(obj.serialize_to_bytes()).decode("utf-8")
        return {
            "_type": FaissSerializer.type(),
            "vectordb": json.dumps(index_b64),
            "embedding_function": dump(vars(obj)["embedding_function"]),
        }
120+
121+
# Mapping class to vector store serialization functions
122+
vectordb_serialization = {
    # Keyed by the vector store's class name, as stored in config["vectordb"]["class"].
    "OpenSearchVectorSearch": OpenSearchVectorDBSerializer,
    "FAISS": FaissSerializer,
}
123+
124+
125+
class RetrievalQASerializer:
    """
    Serializer for RetrievalQA class

    On top of the chain's own ``dict()`` output, the serialized form carries
    two extra keys: ``"retriever_kwargs"`` (the retriever's attributes) and
    ``"vectordb"`` (the vector store's class name merged with the output of
    the serializer registered for it in ``vectordb_serialization``).
    """

    @staticmethod
    def type():
        """Return the ``_type`` identifier handled by this serializer."""
        return "retrieval_qa"

    @staticmethod
    def load(config: dict, **kwargs):
        """Rebuild a RetrievalQA chain from a config produced by ``save``.

        Parameters
        ----------
        config : dict
            Serialized chain. It is deep-copied first, so the caller's
            dictionary is not modified.
        **kwargs
            Forwarded to the vector store serializer's ``load``.

        Returns
        -------
        The chain returned by ``load_chain_from_config``, wired to a
        retriever created from the deserialized vector store.
        """
        config_param = deepcopy(config)
        # These two keys are added by ``save`` and are not part of the chain
        # config, so remove them before handing it to load_chain_from_config.
        retriever_kwargs = config_param.pop("retriever_kwargs")
        vectordb_config = config_param.pop("vectordb")

        vectordb_serializer = vectordb_serialization[vectordb_config["class"]]
        vectordb = vectordb_serializer.load(vectordb_config, **kwargs)
        retriever = vectordb.as_retriever(**retriever_kwargs)
        return load_chain_from_config(config=config_param, retriever=retriever)

    @staticmethod
    def save(obj):
        """Serialize a RetrievalQA chain together with its retriever settings.

        Parameters
        ----------
        obj : RetrievalQA
            The chain to serialize. Its retriever must expose a
            ``vectorstore`` attribute.

        Returns
        -------
        dict
            ``obj.dict()`` extended with ``"retriever_kwargs"`` and
            ``"vectordb"``.

        Raises
        ------
        NotImplementedError
            If no serializer is registered for the vector store's class.
        """
        vectorstore = obj.retriever.vectorstore
        vectordb_class = vectorstore.__class__.__name__
        # Validate before looking the serializer up; otherwise an unsupported
        # vector store surfaces as a bare KeyError and this error is never
        # raised.
        if vectordb_class not in vectordb_serialization:
            raise NotImplementedError(
                f"VectorDBSerializer for {vectordb_class} is not implemented."
            )

        serialized = obj.dict()
        # "vectorstore" is excluded here because it is serialized separately
        # under "vectordb"; "tags" and "metadata" are left out as well.
        serialized["retriever_kwargs"] = {
            key: val
            for key, val in obj.retriever.__dict__.items()
            if key not in ("tags", "metadata", "vectorstore")
        }
        serialized["vectordb"] = {"class": vectordb_class}
        serialized["vectordb"].update(
            vectordb_serialization[vectordb_class].save(vectorstore)
        )
        return serialized
162+
163+
42164
# Mapping class to custom serialization functions.
# Each key is a class and each value is the ``save`` callable registered for it.
custom_serialization = {
    GuardrailSequence: GuardrailSequence.save,
    CustomGuardrailBase: CustomGuardrailBase.save,
    RunnableParallel: RunnableParallelSerializer.save,
    RetrievalQA: RetrievalQASerializer.save,
}
48171

49172
# Mapping _type to custom deserialization functions
@@ -52,6 +175,7 @@ def __new_type_to_cls_dict():
52175
GuardrailSequence.type(): GuardrailSequence.load,
53176
CustomGuardrailBase.type(): CustomGuardrailBase.load,
54177
RunnableParallelSerializer.type(): RunnableParallelSerializer.load,
178+
RetrievalQASerializer.type(): RetrievalQASerializer.load,
55179
}
56180

57181

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import unittest
2+
from langchain.load.serializable import Serializable
3+
from langchain.schema.embeddings import Embeddings
4+
5+
from langchain.vectorstores import OpenSearchVectorSearch, FAISS
6+
7+
8+
import unittest
9+
from ads.llm.serialize import OpenSearchVectorDBSerializer, FaissSerializer, RetrievalQASerializer
10+
from tests.unitary.with_extras.langchain.test_guardrails import FakeLLM
11+
import os
12+
from unittest import mock
13+
from typing import Any, Dict, List, Mapping, Optional
14+
from langchain.chains import RetrievalQA
15+
from langchain import llms
16+
from langchain.llms import loading
17+
18+
19+
20+
21+
class FakeEmbeddings(Serializable, Embeddings):
    """Fake embeddings model for testing purpose.

    Every text is mapped to the same vector: 1024 ones.
    """

    @property
    def _llm_type(self) -> str:
        """Type identifier of this fake model."""
        return "custom_embeddings"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """This class can be serialized with default LangChain serialization."""
        return True

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one constant vector per input text."""
        vectors = []
        for _ in texts:
            vectors.append([1] * 1024)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Return the constant vector, regardless of the query text."""
        return [1] * 1024
38+
39+
40+
class TestOpensearchSearchVectorSerializers(unittest.TestCase):
    """Tests for OpenSearchVectorDBSerializer (type, save and load)."""

    @classmethod
    def setUpClass(cls):
        """Patch the OpenSearch environment variables and build the store under test."""
        super().setUpClass()
        cls.env_patcher = mock.patch.dict(
            os.environ,
            {
                "OCI_OPENSEARCH_USERNAME": "username",
                "OCI_OPENSEARCH_PASSWORD": "password",
                "OCI_OPENSEARCH_VERIFY_CERTS": "True",
                "OCI_OPENSEARCH_CA_CERTS": "/path/to/cert.pem",
            },
        )
        cls.env_patcher.start()
        cls.index_name = "test_index"
        cls.embeddings = FakeEmbeddings()
        cls.opensearch = OpenSearchVectorSearch(
            "https://localhost:8888",
            embedding_function=cls.embeddings,
            index_name=cls.index_name,
            engine="lucene",
            http_auth=(
                os.environ["OCI_OPENSEARCH_USERNAME"],
                os.environ["OCI_OPENSEARCH_PASSWORD"],
            ),
            verify_certs=os.environ["OCI_OPENSEARCH_VERIFY_CERTS"],
            ca_certs=os.environ["OCI_OPENSEARCH_CA_CERTS"],
        )
        cls.serializer = OpenSearchVectorDBSerializer()

    @classmethod
    def tearDownClass(cls) -> None:
        """Undo the environment patch so it does not leak into other tests."""
        cls.env_patcher.stop()
        super().tearDownClass()

    def test_type(self):
        """type() returns the vector store's class name."""
        self.assertEqual(self.serializer.type(), "OpenSearchVectorSearch")

    def test_save(self):
        """save() records the class id, URL, engine and _type."""
        serialized = self.serializer.save(self.opensearch)
        self.assertEqual(
            serialized["id"],
            [
                "langchain",
                "vectorstores",
                "opensearch_vector_search",
                "OpenSearchVectorSearch",
            ],
        )
        self.assertEqual(
            serialized["kwargs"]["opensearch_url"], "https://localhost:8888"
        )
        self.assertEqual(serialized["kwargs"]["engine"], "lucene")
        self.assertEqual(serialized["_type"], "OpenSearchVectorSearch")

    def test_load(self):
        """load() turns the output of save() back into an OpenSearchVectorSearch."""
        serialized = self.serializer.save(self.opensearch)
        new_opensearch = self.serializer.load(serialized, valid_namespaces=["tests"])
        self.assertIsInstance(new_opensearch, OpenSearchVectorSearch)
77+
78+
79+
class TestFAISSSerializers(unittest.TestCase):
    """Tests for FaissSerializer (type, save and load)."""

    @classmethod
    def setUpClass(cls):
        """Build a one-document FAISS store backed by the fake embeddings."""
        cls.embeddings = FakeEmbeddings()
        pairs = [("test", [1] * 1024)]
        cls.db = FAISS.from_embeddings(pairs, cls.embeddings)
        cls.serializer = FaissSerializer()
        super().setUpClass()

    def test_type(self):
        """type() returns the vector store's class name."""
        self.assertEqual(self.serializer.type(), "FAISS")

    def test_save(self):
        """save() stores the embedding function id and the index as a string."""
        payload = self.serializer.save(self.db)
        expected_id = [
            "tests",
            "unitary",
            "with_extras",
            "langchain",
            "test_serializers",
            "FakeEmbeddings",
        ]
        self.assertEqual(payload["embedding_function"]["id"], expected_id)
        self.assertIsInstance(payload["vectordb"], str)

    def test_load(self):
        """load() turns the output of save() back into a FAISS store."""
        payload = self.serializer.save(self.db)
        restored = self.serializer.load(payload, valid_namespaces=["tests"])
        self.assertIsInstance(restored, FAISS)
100+
101+
102+
class TestRetrievalQASerializer(unittest.TestCase):
    """Tests for RetrievalQASerializer (type, save and load)."""

    @classmethod
    def setUpClass(cls):
        """Build a RetrievalQA chain over FAISS and register the fake LLM types."""
        super().setUpClass()
        # Create a sample RetrievalQA object for testing
        cls.llm = FakeLLM()
        cls.embeddings = FakeEmbeddings()
        text_embedding_pair = [("test", [1] * 1024)]
        cls.db = FAISS.from_embeddings(text_embedding_pair, cls.embeddings)
        cls.retriever = cls.db.as_retriever()
        cls.qa = RetrievalQA.from_chain_type(
            llm=cls.llm, chain_type="stuff", retriever=cls.retriever
        )
        cls.serializer = RetrievalQASerializer()

        # Remember the original getter *functions* (not their return value)
        # so that tearDownClass can put callables back in place.
        cls._original_getters = (
            llms.get_type_to_cls_dict,
            loading.get_type_to_cls_dict,
        )

        # Extend a copy of the registry so the test entries do not end up in
        # the dictionary returned by the original getter.
        type_to_cls = dict(llms.get_type_to_cls_dict())
        type_to_cls["custom_embedding"] = lambda: FakeEmbeddings
        type_to_cls["custom"] = lambda: FakeLLM

        def patched_type_to_cls_dict():
            return type_to_cls

        llms.get_type_to_cls_dict = patched_type_to_cls_dict
        loading.get_type_to_cls_dict = patched_type_to_cls_dict

    @classmethod
    def tearDownClass(cls) -> None:
        """Restore the original type-registry getters."""
        llms.get_type_to_cls_dict, loading.get_type_to_cls_dict = (
            cls._original_getters
        )
        return super().tearDownClass()

    def test_type(self):
        """type() returns the chain type identifier."""
        self.assertEqual(self.serializer.type(), "retrieval_qa")

    def test_save(self):
        """save() returns a dict that includes the retriever and vector store data."""
        # Serialize the RetrievalQA object
        serialized = self.serializer.save(self.qa)

        # Ensure that the serialized object is a dictionary
        self.assertIsInstance(serialized, dict)

        # Ensure that the serialized object contains the necessary keys
        self.assertIn("combine_documents_chain", serialized)
        self.assertIn("retriever_kwargs", serialized)
        self.assertEqual(serialized["vectordb"]["class"], "FAISS")

    def test_load(self):
        """load() turns the output of save() back into a RetrievalQA chain."""
        # Create a sample config dictionary
        serialized = self.serializer.save(self.qa)

        # Deserialize the serialized object
        deserialized = self.serializer.load(serialized, valid_namespaces=["tests"])

        # Ensure that the deserialized object is an instance of RetrievalQA
        self.assertIsInstance(deserialized, RetrievalQA)
158+
159+
160+
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)