Skip to content

Commit 02d940f

Browse files
committed
Update test_serializers.py to avoid modifying langchain.llms.get_type_to_cls_dict
1 parent 407f3d2 commit 02d940f

File tree

1 file changed

+3
-22
lines changed

1 file changed

+3
-22
lines changed

tests/unitary/with_extras/langchain/test_serializers.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,13 @@
1313
from langchain.schema.embeddings import Embeddings
1414
from langchain.vectorstores import OpenSearchVectorSearch, FAISS
1515
from langchain.chains import RetrievalQA
16-
from langchain import llms
17-
from langchain.llms import loading
16+
from langchain.llms import Cohere
1817

1918
from ads.llm.serializers.retrieval_qa import (
2019
OpenSearchVectorDBSerializer,
2120
FaissSerializer,
2221
RetrievalQASerializer,
2322
)
24-
from tests.unitary.with_extras.langchain.test_guardrails import FakeLLM
2523

2624

2725
class FakeEmbeddings(Serializable, Embeddings):
@@ -135,7 +133,7 @@ class TestRetrievalQASerializer(unittest.TestCase):
135133
@classmethod
136134
def setUpClass(cls):
137135
# Create a sample RetrieverQA object for testing
138-
cls.llm = FakeLLM()
136+
cls.llm = Cohere(cohere_api_key="api_key")
139137
cls.embeddings = FakeEmbeddings()
140138
text_embedding_pair = [("test", [1] * 1024)]
141139
try:
@@ -148,18 +146,6 @@ def setUpClass(cls):
148146
llm=cls.llm, chain_type="stuff", retriever=cls.retriever
149147
)
150148
cls.serializer = RetrievalQASerializer()
151-
from copy import deepcopy
152-
153-
cls.original_type_to_cls_dict = deepcopy(llms.get_type_to_cls_dict())
154-
__lc_llm_dict = llms.get_type_to_cls_dict()
155-
__lc_llm_dict["custom_embedding"] = lambda: FakeEmbeddings
156-
__lc_llm_dict["custom"] = lambda: FakeLLM
157-
158-
def __new_type_to_cls_dict():
159-
return __lc_llm_dict
160-
161-
llms.get_type_to_cls_dict = __new_type_to_cls_dict
162-
loading.get_type_to_cls_dict = __new_type_to_cls_dict
163149

164150
def test_type(self):
165151
self.assertEqual(self.serializer.type(), "retrieval_qa")
@@ -176,6 +162,7 @@ def test_save(self):
176162
self.assertIn("retriever_kwargs", serialized)
177163
serialized["vectordb"]["class"] == "FAISS"
178164

165+
@mock.patch.dict(os.environ, {"COHERE_API_KEY": "api_key"})
179166
def test_load(self):
180167
# Create a sample config dictionary
181168
serialized = self.serializer.save(self.qa)
@@ -186,12 +173,6 @@ def test_load(self):
186173
# Ensure that the deserialized object is an instance of RetrieverQA
187174
self.assertIsInstance(deserialized, RetrievalQA)
188175

189-
@classmethod
190-
def tearDownClass(cls) -> None:
191-
llms.get_type_to_cls_dict = cls.original_type_to_cls_dict
192-
loading.get_type_to_cls_dict = cls.original_type_to_cls_dict
193-
return super().tearDownClass()
194-
195176

196177
if __name__ == "__main__":
197178
unittest.main()

0 commit comments

Comments
 (0)