1
+ import unittest
2
+ from langchain .load .serializable import Serializable
3
+ from langchain .schema .embeddings import Embeddings
4
+
5
+ from langchain .vectorstores import OpenSearchVectorSearch , FAISS
6
+
7
+
8
+ import unittest
9
+ from ads .llm .serialize import OpenSearchVectorDBSerializer , FaissSerializer , RetrievalQASerializer
10
+ from tests .unitary .with_extras .langchain .test_guardrails import FakeLLM
11
+ import os
12
+ from unittest import mock
13
+ from typing import Any , Dict , List , Mapping , Optional
14
+ from langchain .chains import RetrievalQA
15
+ from langchain import llms
16
+ from langchain .llms import loading
17
+
18
+
19
+
20
+
21
class FakeEmbeddings(Serializable, Embeddings):
    """Fake embeddings model for testing purposes.

    Every input text is mapped to the same 1024-element vector of ones,
    so no real embedding backend is needed.
    """

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that default LangChain serialization may be used for this class."""
        return True

    @property
    def _llm_type(self) -> str:
        """Type identifier string reported by this fake."""
        return "custom_embeddings"

    def embed_query(self, text: str) -> List[float]:
        """Return the fixed 1024-element vector of ones, ignoring ``text``."""
        return [1] * 1024

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one fixed vector per entry in ``texts``."""
        return [self.embed_query(document) for document in texts]
38
+
39
+
40
class TestOpensearchSearchVectorSerializers(unittest.TestCase):
    """Save/load tests for ``OpenSearchVectorDBSerializer``.

    The OpenSearch credentials and TLS settings are supplied through
    environment variables that are patched in ``setUpClass`` and restored
    in ``tearDownClass``.
    """

    @classmethod
    def setUpClass(cls):
        """Patch the environment and build the vector store under test."""
        cls.env_patcher = mock.patch.dict(
            os.environ,
            {
                "oci_opensearch_username": "username",
                "oci_opensearch_password": "password",
                "oci_opensearch_verify_certs": "True",
                "oci_opensearch_ca_certs": "/path/to/cert.pem",
            },
        )
        cls.env_patcher.start()
        cls.index_name = "test_index"
        cls.embeddings = FakeEmbeddings()
        cls.opensearch = OpenSearchVectorSearch(
            "https://localhost:8888",
            embedding_function=cls.embeddings,
            index_name=cls.index_name,
            engine="lucene",
            http_auth=(
                os.environ["oci_opensearch_username"],
                os.environ["oci_opensearch_password"],
            ),
            # NOTE(review): this passes the string "True", not a bool -- confirm intended.
            verify_certs=os.environ["oci_opensearch_verify_certs"],
            ca_certs=os.environ["oci_opensearch_ca_certs"],
        )
        cls.serializer = OpenSearchVectorDBSerializer()
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        """Undo the environment patch so the fake settings do not leak into other tests."""
        cls.env_patcher.stop()
        super().tearDownClass()

    def test_type(self):
        """The serializer reports the vector store class name as its type."""
        self.assertEqual(self.serializer.type(), "OpenSearchVectorSearch")

    def test_save(self):
        """Saving yields the expected id path, constructor kwargs and type tag."""
        serialized = self.serializer.save(self.opensearch)
        self.assertEqual(
            serialized["id"],
            ["langchain", "vectorstores", "opensearch_vector_search", "OpenSearchVectorSearch"],
        )
        self.assertEqual(serialized["kwargs"]["opensearch_url"], "https://localhost:8888")
        self.assertEqual(serialized["kwargs"]["engine"], "lucene")
        self.assertEqual(serialized["_type"], "OpenSearchVectorSearch")

    def test_load(self):
        """Loading a saved payload returns an ``OpenSearchVectorSearch`` instance."""
        serialized = self.serializer.save(self.opensearch)
        new_opensearch = self.serializer.load(serialized, valid_namespaces=["tests"])
        self.assertIsInstance(new_opensearch, OpenSearchVectorSearch)
77
+
78
+
79
class TestFAISSSerializers(unittest.TestCase):
    """Save/load tests for ``FaissSerializer`` on a one-document FAISS store."""

    @classmethod
    def setUpClass(cls):
        """Build a FAISS store from a single precomputed (text, embedding) pair."""
        cls.embeddings = FakeEmbeddings()
        text_embedding_pair = [("test", [1] * 1024)]
        cls.db = FAISS.from_embeddings(text_embedding_pair, cls.embeddings)
        cls.serializer = FaissSerializer()
        super().setUpClass()

    def test_type(self):
        """The serializer reports ``"FAISS"`` as its type."""
        self.assertEqual(self.serializer.type(), "FAISS")

    def test_save(self):
        """Saving records the embedding class id path and a string ``vectordb`` payload."""
        serialized = self.serializer.save(self.db)
        # unittest assertions are used instead of bare ``assert`` so the checks
        # are not stripped under ``python -O`` and failures show both values.
        self.assertEqual(
            serialized["embedding_function"]["id"],
            ["tests", "unitary", "with_extras", "langchain", "test_serializers", "FakeEmbeddings"],
        )
        self.assertIsInstance(serialized["vectordb"], str)

    def test_load(self):
        """Loading a saved payload returns a ``FAISS`` instance."""
        serialized = self.serializer.save(self.db)
        new_db = self.serializer.load(serialized, valid_namespaces=["tests"])
        self.assertIsInstance(new_db, FAISS)
100
+
101
+
102
class TestRetrievalQASerializer(unittest.TestCase):
    """Save/load tests for ``RetrievalQASerializer`` on a FAISS-backed chain."""

    @classmethod
    def setUpClass(cls):
        """Build a sample RetrievalQA chain and register the fake types.

        ``get_type_to_cls_dict`` is replaced on both ``llms`` and ``loading``
        with a function returning a registry that also contains the fakes.
        The replacements are reverted in ``tearDownClass``.
        """
        cls.llm = FakeLLM()
        cls.embeddings = FakeEmbeddings()
        text_embedding_pair = [("test", [1] * 1024)]
        cls.db = FAISS.from_embeddings(text_embedding_pair, cls.embeddings)
        cls.retriever = cls.db.as_retriever()
        cls.qa = RetrievalQA.from_chain_type(
            llm=cls.llm,
            chain_type="stuff",
            retriever=cls.retriever,
        )
        cls.serializer = RetrievalQASerializer()

        # Extend the registry returned by langchain with the test doubles.
        type_to_cls = llms.get_type_to_cls_dict()
        # NOTE(review): this key differs from FakeEmbeddings._llm_type
        # ("custom_embeddings") -- confirm which spelling is intended.
        type_to_cls["custom_embedding"] = lambda: FakeEmbeddings
        type_to_cls["custom"] = lambda: FakeLLM

        def patched_get_type_to_cls_dict():
            return type_to_cls

        # patch.object remembers the original *functions*, so stopping the
        # patchers restores callables rather than a plain dict.
        cls._registry_patchers = [
            mock.patch.object(llms, "get_type_to_cls_dict", patched_get_type_to_cls_dict),
            mock.patch.object(loading, "get_type_to_cls_dict", patched_get_type_to_cls_dict),
        ]
        for patcher in cls._registry_patchers:
            patcher.start()
        super().setUpClass()

    def test_type(self):
        """The serializer reports ``"retrieval_qa"`` as its type."""
        self.assertEqual(self.serializer.type(), "retrieval_qa")

    def test_save(self):
        """Saving yields a dict with the chain, retriever and vector store entries."""
        # Serialize the RetrievalQA object
        serialized = self.serializer.save(self.qa)

        # Ensure that the serialized object is a dictionary
        self.assertIsInstance(serialized, dict)

        # Ensure that the serialized object contains the necessary keys
        self.assertIn("combine_documents_chain", serialized)
        self.assertIn("retriever_kwargs", serialized)
        # Previously a bare comparison whose result was discarded.
        self.assertEqual(serialized["vectordb"]["class"], "FAISS")

    def test_load(self):
        """Loading a saved payload returns a ``RetrievalQA`` instance."""
        serialized = self.serializer.save(self.qa)

        # Deserialize the serialized object
        deserialized = self.serializer.load(serialized, valid_namespaces=["tests"])

        # Ensure that the deserialized object is an instance of RetrievalQA
        self.assertIsInstance(deserialized, RetrievalQA)

    @classmethod
    def tearDownClass(cls) -> None:
        """Restore the original ``get_type_to_cls_dict`` functions."""
        for patcher in reversed(cls._registry_patchers):
            patcher.stop()
        return super().tearDownClass()
158
+
159
+
160
# Run the test cases when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
0 commit comments