13
13
from langchain .schema .embeddings import Embeddings
14
14
from langchain .vectorstores import OpenSearchVectorSearch , FAISS
15
15
from langchain .chains import RetrievalQA
16
- from langchain import llms
17
- from langchain .llms import loading
16
+ from langchain .llms import Cohere
18
17
19
18
from ads .llm .serializers .retrieval_qa import (
20
19
OpenSearchVectorDBSerializer ,
21
20
FaissSerializer ,
22
21
RetrievalQASerializer ,
23
22
)
24
- from tests .unitary .with_extras .langchain .test_guardrails import FakeLLM
25
23
26
24
27
25
class FakeEmbeddings (Serializable , Embeddings ):
@@ -135,7 +133,7 @@ class TestRetrievalQASerializer(unittest.TestCase):
135
133
@classmethod
136
134
def setUpClass (cls ):
137
135
# Create a sample RetrieverQA object for testing
138
- cls .llm = FakeLLM ( )
136
+ cls .llm = Cohere ( cohere_api_key = "api_key" )
139
137
cls .embeddings = FakeEmbeddings ()
140
138
text_embedding_pair = [("test" , [1 ] * 1024 )]
141
139
try :
@@ -148,18 +146,6 @@ def setUpClass(cls):
148
146
llm = cls .llm , chain_type = "stuff" , retriever = cls .retriever
149
147
)
150
148
cls .serializer = RetrievalQASerializer ()
151
- from copy import deepcopy
152
-
153
- cls .original_type_to_cls_dict = deepcopy (llms .get_type_to_cls_dict ())
154
- __lc_llm_dict = llms .get_type_to_cls_dict ()
155
- __lc_llm_dict ["custom_embedding" ] = lambda : FakeEmbeddings
156
- __lc_llm_dict ["custom" ] = lambda : FakeLLM
157
-
158
- def __new_type_to_cls_dict ():
159
- return __lc_llm_dict
160
-
161
- llms .get_type_to_cls_dict = __new_type_to_cls_dict
162
- loading .get_type_to_cls_dict = __new_type_to_cls_dict
163
149
164
150
def test_type (self ):
165
151
self .assertEqual (self .serializer .type (), "retrieval_qa" )
@@ -176,6 +162,7 @@ def test_save(self):
176
162
self .assertIn ("retriever_kwargs" , serialized )
177
163
serialized ["vectordb" ]["class" ] == "FAISS"
178
164
165
+ @mock .patch .dict (os .environ , {"COHERE_API_KEY" : "api_key" })
179
166
def test_load (self ):
180
167
# Create a sample config dictionary
181
168
serialized = self .serializer .save (self .qa )
@@ -186,12 +173,6 @@ def test_load(self):
186
173
# Ensure that the deserialized object is an instance of RetrieverQA
187
174
self .assertIsInstance (deserialized , RetrievalQA )
188
175
189
- @classmethod
190
- def tearDownClass (cls ) -> None :
191
- llms .get_type_to_cls_dict = cls .original_type_to_cls_dict
192
- loading .get_type_to_cls_dict = cls .original_type_to_cls_dict
193
- return super ().tearDownClass ()
194
-
195
176
196
177
if __name__ == "__main__" :
197
178
unittest .main ()
0 commit comments