Skip to content

Commit 8b83df0

Browse files
committed
adding unit test
1 parent 9353b2e commit 8b83df0

File tree

2 files changed

+171
-12
lines changed

2 files changed

+171
-12
lines changed

ads/llm/serialize.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import base64
78
import json
89
import os
910
import tempfile
@@ -19,18 +20,15 @@
1920
from langchain.load import dumpd
2021
from langchain.load.load import load as lc_load
2122
from langchain.load.serializable import Serializable
22-
from langchain.vectorstores import OpenSearchVectorSearch, FAISS
23+
from langchain.vectorstores import FAISS, OpenSearchVectorSearch
2324
from opensearchpy.client import OpenSearch
2425

2526
from ads.common.auth import default_signer
2627
from ads.common.object_storage_details import ObjectStorageDetails
27-
from ads.llm import GenerativeAI, ModelDeploymentVLLM, ModelDeploymentTGI
28+
from ads.llm import GenerativeAI, ModelDeploymentTGI, ModelDeploymentVLLM
2829
from ads.llm.chain import GuardrailSequence
2930
from ads.llm.guardrails.base import CustomGuardrailBase
3031
from ads.llm.patch import RunnableParallel, RunnableParallelSerializer
31-
import json
32-
import base64
33-
3432

3533
# This is a temporary solution for supporting custom LLMs in the legacy load_chain
3634
__lc_llm_dict = llms.get_type_to_cls_dict()
@@ -58,15 +56,15 @@ def type():
5856
@staticmethod
def load(config: dict, **kwargs):
    """Rebuild an ``OpenSearchVectorSearch`` from its serialized config.

    The serialized embedding function stored under
    ``config["kwargs"]["embedding_function"]`` is deserialized first (with
    ``kwargs`` forwarded to the loader) and written back into ``config``,
    so ``config`` is modified in place.

    Connection settings are read from environment variables:
    ``oci_opensearch_username``, ``oci_opensearch_password``,
    ``oci_opensearch_verify_certs`` and ``oci_opensearch_ca_certs``.

    Parameters
    ----------
    config : dict
        Serialized vector store; must contain a ``"kwargs"`` mapping with an
        ``"embedding_function"`` entry.
    **kwargs
        Extra keyword arguments forwarded to the embedding-function loader.

    Returns
    -------
    OpenSearchVectorSearch
        The reconstructed vector store.
    """
    config["kwargs"]["embedding_function"] = load(
        config["kwargs"]["embedding_function"], **kwargs
    )
    # Default to "" rather than None: calling ``.lower()`` on a missing
    # variable raised AttributeError. Unset now means "do not verify";
    # only the text "true" (any letter case) enables verification.
    verify_certs = (
        os.environ.get("oci_opensearch_verify_certs", "").lower() == "true"
    )
    return OpenSearchVectorSearch(
        **config["kwargs"],
        http_auth=(
            os.environ.get("oci_opensearch_username", None),
            os.environ.get("oci_opensearch_password", None),
        ),
        verify_certs=verify_certs,
        ca_certs=os.environ.get("oci_opensearch_ca_certs", None),
    )
7270

@@ -102,7 +100,7 @@ def type():
102100

103101
@staticmethod
104102
def load(config: dict, **kwargs):
105-
embedding_function = load(config["embedding_function"])
103+
embedding_function = load(config["embedding_function"], **kwargs)
106104
decoded_pkl = base64.b64decode(json.loads(config["vectordb"]))
107105
return FAISS.deserialize_from_bytes(
108106
embeddings=embedding_function, serialized=decoded_pkl
@@ -124,7 +122,7 @@ def save(obj):
124122
vectordb_serialization = {"OpenSearchVectorSearch": OpenSearchVectorDBSerializer, "FAISS": FaissSerializer}
125123

126124

127-
class RetrieverQASerializer:
125+
class RetrievalQASerializer:
128126
"""
129127
Serializer for the RetrievalQA class.
130128
"""
@@ -137,7 +135,7 @@ def load(config: dict, **kwargs):
137135
config_param = deepcopy(config)
138136
retriever_kwargs = config_param.pop("retriever_kwargs")
139137
vectordb_serializer = vectordb_serialization[config_param["vectordb"]["class"]]
140-
vectordb = vectordb_serializer.load(config_param.pop("vectordb"))
138+
vectordb = vectordb_serializer.load(config_param.pop("vectordb"), **kwargs)
141139
retriever = vectordb.as_retriever(**retriever_kwargs)
142140
return load_chain_from_config(config=config_param, retriever=retriever)
143141

@@ -168,7 +166,7 @@ def save(obj):
168166
GuardrailSequence: GuardrailSequence.save,
169167
CustomGuardrailBase: CustomGuardrailBase.save,
170168
RunnableParallel: RunnableParallelSerializer.save,
171-
RetrievalQA: RetrieverQASerializer.save,
169+
RetrievalQA: RetrievalQASerializer.save,
172170
}
173171

174172
# Mapping _type to custom deserialization functions
@@ -177,7 +175,7 @@ def save(obj):
177175
GuardrailSequence.type(): GuardrailSequence.load,
178176
CustomGuardrailBase.type(): CustomGuardrailBase.load,
179177
RunnableParallelSerializer.type(): RunnableParallelSerializer.load,
180-
RetrieverQASerializer.type(): RetrieverQASerializer.load,
178+
RetrievalQASerializer.type(): RetrievalQASerializer.load,
181179
}
182180

183181

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import unittest
2+
from langchain.load.serializable import Serializable
3+
from langchain.schema.embeddings import Embeddings
4+
5+
from langchain.vectorstores import OpenSearchVectorSearch, FAISS
6+
7+
8+
import unittest
9+
from ads.llm.serialize import OpenSearchVectorDBSerializer, FaissSerializer, RetrievalQASerializer
10+
from tests.unitary.with_extras.langchain.test_guardrails import FakeLLM
11+
import os
12+
from unittest import mock
13+
from typing import Any, Dict, List, Mapping, Optional
14+
from langchain.chains import RetrievalQA
15+
from langchain import llms
16+
from langchain.llms import loading
17+
18+
19+
20+
21+
class FakeEmbeddings(Serializable, Embeddings):
    """Fake embedding model for testing purposes.

    Every text is embedded as the same 1024-element vector of ones, so the
    tests need no real embedding backend.
    """

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that default LangChain serialization applies to this class."""
        return True

    @property
    def _llm_type(self) -> str:
        return "custom_embeddings"

    def embed_query(self, text: str) -> List[float]:
        """Return the fixed 1024-element vector for a single query text."""
        return [1] * 1024

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one fixed vector per input text."""
        return [self.embed_query(text) for text in texts]
38+
39+
40+
class TestOpensearchSearchVectorSerializers(unittest.TestCase):
    """Tests for ``OpenSearchVectorDBSerializer`` (type, save and load)."""

    @classmethod
    def setUpClass(cls):
        """Patch the OpenSearch environment variables and build the fixtures."""
        super().setUpClass()
        cls.env_patcher = mock.patch.dict(
            os.environ,
            {
                "oci_opensearch_username": "username",
                "oci_opensearch_password": "password",
                "oci_opensearch_verify_certs": "True",
                "oci_opensearch_ca_certs": "/path/to/cert.pem",
            },
        )
        cls.env_patcher.start()
        cls.index_name = "test_index"
        cls.embeddings = FakeEmbeddings()
        cls.opensearch = OpenSearchVectorSearch(
            "https://localhost:8888",
            embedding_function=cls.embeddings,
            index_name=cls.index_name,
            engine="lucene",
            http_auth=(
                os.environ["oci_opensearch_username"],
                os.environ["oci_opensearch_password"],
            ),
            verify_certs=os.environ["oci_opensearch_verify_certs"],
            ca_certs=os.environ["oci_opensearch_ca_certs"],
        )
        cls.serializer = OpenSearchVectorDBSerializer()

    @classmethod
    def tearDownClass(cls):
        """Undo the environment patch so it cannot leak into other tests."""
        cls.env_patcher.stop()
        super().tearDownClass()

    def test_type(self):
        """The serializer reports the vector store's type name."""
        self.assertEqual(self.serializer.type(), "OpenSearchVectorSearch")

    def test_save(self):
        """Saving yields the expected id, constructor kwargs and type tag."""
        serialized = self.serializer.save(self.opensearch)
        self.assertEqual(
            serialized["id"],
            [
                "langchain",
                "vectorstores",
                "opensearch_vector_search",
                "OpenSearchVectorSearch",
            ],
        )
        self.assertEqual(
            serialized["kwargs"]["opensearch_url"], "https://localhost:8888"
        )
        self.assertEqual(serialized["kwargs"]["engine"], "lucene")
        self.assertEqual(serialized["_type"], "OpenSearchVectorSearch")

    def test_load(self):
        """A saved store can be loaded back into an ``OpenSearchVectorSearch``."""
        serialized = self.serializer.save(self.opensearch)
        new_opensearch = self.serializer.load(serialized, valid_namespaces=["tests"])
        self.assertIsInstance(new_opensearch, OpenSearchVectorSearch)
77+
78+
79+
class TestFAISSSerializers(unittest.TestCase):
80+
@classmethod
81+
def setUpClass(cls):
82+
cls.embeddings = FakeEmbeddings()
83+
text_embedding_pair = [("test", [1] * 1024)]
84+
cls.db = FAISS.from_embeddings(text_embedding_pair, cls.embeddings)
85+
cls.serializer = FaissSerializer()
86+
super().setUpClass()
87+
88+
def test_type(self):
89+
self.assertEqual(self.serializer.type(), "FAISS")
90+
91+
def test_save(self):
92+
serialized = self.serializer.save(self.db)
93+
assert serialized["embedding_function"]["id"] == ["tests", "unitary", "with_extras", "langchain", "test_serializers", "FakeEmbeddings"]
94+
assert isinstance(serialized["vectordb"], str)
95+
96+
def test_load(self):
97+
serialized = self.serializer.save(self.db)
98+
new_db = self.serializer.load(serialized, valid_namespaces=["tests"])
99+
assert isinstance(new_db, FAISS)
100+
101+
102+
class TestRetrievalQASerializer(unittest.TestCase):
103+
@classmethod
104+
def setUpClass(cls):
105+
# Create a sample RetrieverQA object for testing
106+
cls.llm = FakeLLM()
107+
cls.embeddings = FakeEmbeddings()
108+
text_embedding_pair = [("test", [1] * 1024)]
109+
cls.db = FAISS.from_embeddings(text_embedding_pair, cls.embeddings)
110+
cls.serializer = FaissSerializer()
111+
cls.retriever = cls.db.as_retriever()
112+
cls.qa = RetrievalQA.from_chain_type(llm=cls.llm,
113+
chain_type="stuff",
114+
retriever=cls.retriever)
115+
cls.serializer = RetrievalQASerializer()
116+
from copy import deepcopy
117+
cls.original_type_to_cls_dict = deepcopy(llms.get_type_to_cls_dict())
118+
__lc_llm_dict = llms.get_type_to_cls_dict()
119+
__lc_llm_dict["custom_embedding"] = lambda: FakeEmbeddings
120+
__lc_llm_dict["custom"] = lambda: FakeLLM
121+
122+
def __new_type_to_cls_dict():
123+
return __lc_llm_dict
124+
125+
llms.get_type_to_cls_dict = __new_type_to_cls_dict
126+
loading.get_type_to_cls_dict = __new_type_to_cls_dict
127+
128+
def test_type(self):
129+
self.assertEqual(self.serializer.type(), "retrieval_qa")
130+
131+
def test_save(self):
132+
# Serialize the RetrieverQA object
133+
serialized = self.serializer.save(self.qa)
134+
135+
# Ensure that the serialized object is a dictionary
136+
self.assertIsInstance(serialized, dict)
137+
138+
# Ensure that the serialized object contains the necessary keys
139+
self.assertIn("combine_documents_chain", serialized)
140+
self.assertIn("retriever_kwargs", serialized)
141+
serialized["vectordb"]["class"] == "FAISS"
142+
143+
def test_load(self):
144+
# Create a sample config dictionary
145+
serialized = self.serializer.save(self.qa)
146+
147+
# Deserialize the serialized object
148+
deserialized = self.serializer.load(serialized, valid_namespaces=["tests"])
149+
150+
# Ensure that the deserialized object is an instance of RetrieverQA
151+
self.assertIsInstance(deserialized, RetrievalQA)
152+
153+
@classmethod
154+
def tearDownClass(cls) -> None:
155+
llms.get_type_to_cls_dict = cls.original_type_to_cls_dict
156+
loading.get_type_to_cls_dict = cls.original_type_to_cls_dict
157+
return super().tearDownClass()
158+
159+
160+
if __name__ == "__main__":
161+
unittest.main()

0 commit comments

Comments
 (0)