1
- import unittest
2
- from langchain .load .serializable import Serializable
3
- from langchain .schema .embeddings import Embeddings
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*--
4
3
5
- from langchain .vectorstores import OpenSearchVectorSearch , FAISS
4
+ # Copyright (c) 2023 Oracle and/or its affiliates.
5
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
7
7
8
- import unittest
9
- from ads .llm .serialize import OpenSearchVectorDBSerializer , FaissSerializer , RetrievalQASerializer
10
- from tests .unitary .with_extras .langchain .test_guardrails import FakeLLM
11
8
import os
9
+ import unittest
12
10
from unittest import mock
13
- from typing import Any , Dict , List , Mapping , Optional
11
+ from typing import List
12
+ from langchain .load .serializable import Serializable
13
+ from langchain .schema .embeddings import Embeddings
14
+ from langchain .vectorstores import OpenSearchVectorSearch , FAISS
14
15
from langchain .chains import RetrievalQA
15
16
from langchain import llms
16
17
from langchain .llms import loading
17
18
18
-
19
+ from ads .llm .serializers .retrieval_qa import (
20
+ OpenSearchVectorDBSerializer ,
21
+ FaissSerializer ,
22
+ RetrievalQASerializer ,
23
+ )
24
+ from tests .unitary .with_extras .langchain .test_guardrails import FakeLLM
19
25
20
26
21
27
class FakeEmbeddings (Serializable , Embeddings ):
@@ -35,27 +41,38 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
35
41
36
42
def embed_query (self , text : str ) -> List [float ]:
37
43
return [1 ] * 1024
38
-
39
-
44
+
45
+
40
46
class TestOpensearchSearchVectorSerializers (unittest .TestCase ):
41
47
@classmethod
42
48
def setUpClass (cls ):
43
- cls .env_patcher = mock .patch .dict (os .environ , {"OCI_OPENSEARCH_USERNAME" : "username" ,
44
- "OCI_OPENSEARCH_PASSWORD" : "password" ,
45
- "OCI_OPENSEARCH_VERIFY_CERTS" : "True" ,
46
- "OCI_OPENSEARCH_CA_CERTS" : "/path/to/cert.pem" })
49
+ cls .env_patcher = mock .patch .dict (
50
+ os .environ ,
51
+ {
52
+ "OCI_OPENSEARCH_USERNAME" : "username" ,
53
+ "OCI_OPENSEARCH_PASSWORD" : "password" ,
54
+ "OCI_OPENSEARCH_VERIFY_CERTS" : "True" ,
55
+ "OCI_OPENSEARCH_CA_CERTS" : "/path/to/cert.pem" ,
56
+ },
57
+ )
47
58
cls .env_patcher .start ()
48
59
cls .index_name = "test_index"
49
60
cls .embeddings = FakeEmbeddings ()
50
- cls .opensearch = OpenSearchVectorSearch (
51
- "https://localhost:8888" ,
52
- embedding_function = cls .embeddings ,
53
- index_name = cls .index_name ,
54
- engine = "lucene" ,
55
- http_auth = (os .environ ["OCI_OPENSEARCH_USERNAME" ], os .environ ["OCI_OPENSEARCH_PASSWORD" ]),
56
- verify_certs = os .environ ["OCI_OPENSEARCH_VERIFY_CERTS" ],
57
- ca_certs = os .environ ["OCI_OPENSEARCH_CA_CERTS" ],
58
- )
61
+ try :
62
+ cls .opensearch = OpenSearchVectorSearch (
63
+ "https://localhost:8888" ,
64
+ embedding_function = cls .embeddings ,
65
+ index_name = cls .index_name ,
66
+ engine = "lucene" ,
67
+ http_auth = (
68
+ os .environ ["OCI_OPENSEARCH_USERNAME" ],
69
+ os .environ ["OCI_OPENSEARCH_PASSWORD" ],
70
+ ),
71
+ verify_certs = os .environ ["OCI_OPENSEARCH_VERIFY_CERTS" ],
72
+ ca_certs = os .environ ["OCI_OPENSEARCH_CA_CERTS" ],
73
+ )
74
+ except ImportError as ex :
75
+ raise unittest .SkipTest ("opensearch-py is not installed." ) from ex
59
76
cls .serializer = OpenSearchVectorDBSerializer ()
60
77
super ().setUpClass ()
61
78
@@ -65,7 +82,12 @@ def test_type(self):
65
82
66
83
def test_save (self ):
67
84
serialized = self .serializer .save (self .opensearch )
68
- assert serialized ["id" ] == ['langchain' , 'vectorstores' , 'opensearch_vector_search' , 'OpenSearchVectorSearch' ]
85
+ assert serialized ["id" ] == [
86
+ "langchain" ,
87
+ "vectorstores" ,
88
+ "opensearch_vector_search" ,
89
+ "OpenSearchVectorSearch" ,
90
+ ]
69
91
assert serialized ["kwargs" ]["opensearch_url" ] == "https://localhost:8888"
70
92
assert serialized ["kwargs" ]["engine" ] == "lucene"
71
93
assert serialized ["_type" ] == "OpenSearchVectorSearch"
@@ -81,7 +103,10 @@ class TestFAISSSerializers(unittest.TestCase):
81
103
def setUpClass (cls ):
82
104
cls .embeddings = FakeEmbeddings ()
83
105
text_embedding_pair = [("test" , [1 ] * 1024 )]
84
- cls .db = FAISS .from_embeddings (text_embedding_pair , cls .embeddings )
106
+ try :
107
+ cls .db = FAISS .from_embeddings (text_embedding_pair , cls .embeddings )
108
+ except ImportError as ex :
109
+ raise unittest .SkipTest (ex .msg ) from ex
85
110
cls .serializer = FaissSerializer ()
86
111
super ().setUpClass ()
87
112
@@ -90,7 +115,14 @@ def test_type(self):
90
115
91
116
def test_save (self ):
92
117
serialized = self .serializer .save (self .db )
93
- assert serialized ["embedding_function" ]["id" ] == ["tests" , "unitary" , "with_extras" , "langchain" , "test_serializers" , "FakeEmbeddings" ]
118
+ assert serialized ["embedding_function" ]["id" ] == [
119
+ "tests" ,
120
+ "unitary" ,
121
+ "with_extras" ,
122
+ "langchain" ,
123
+ "test_serializers" ,
124
+ "FakeEmbeddings" ,
125
+ ]
94
126
assert isinstance (serialized ["vectordb" ], str )
95
127
96
128
def test_load (self ):
@@ -106,14 +138,18 @@ def setUpClass(cls):
106
138
cls .llm = FakeLLM ()
107
139
cls .embeddings = FakeEmbeddings ()
108
140
text_embedding_pair = [("test" , [1 ] * 1024 )]
109
- cls .db = FAISS .from_embeddings (text_embedding_pair , cls .embeddings )
141
+ try :
142
+ cls .db = FAISS .from_embeddings (text_embedding_pair , cls .embeddings )
143
+ except ImportError as ex :
144
+ raise unittest .SkipTest (ex .msg ) from ex
110
145
cls .serializer = FaissSerializer ()
111
146
cls .retriever = cls .db .as_retriever ()
112
- cls .qa = RetrievalQA .from_chain_type (llm = cls . llm ,
113
- chain_type = "stuff" ,
114
- retriever = cls . retriever )
147
+ cls .qa = RetrievalQA .from_chain_type (
148
+ llm = cls . llm , chain_type = "stuff" , retriever = cls . retriever
149
+ )
115
150
cls .serializer = RetrievalQASerializer ()
116
151
from copy import deepcopy
152
+
117
153
cls .original_type_to_cls_dict = deepcopy (llms .get_type_to_cls_dict ())
118
154
__lc_llm_dict = llms .get_type_to_cls_dict ()
119
155
__lc_llm_dict ["custom_embedding" ] = lambda : FakeEmbeddings
@@ -158,4 +194,4 @@ def tearDownClass(cls) -> None:
158
194
159
195
160
196
if __name__ == "__main__" :
161
- unittest .main ()
197
+ unittest .main ()
0 commit comments