Skip to content

Commit 486070c

Browse files
committed
Update test_serializers.py to skip test if dependencies are not installed.
1 parent 3d02a35 commit 486070c

File tree

1 file changed

+68
-32
lines changed

1 file changed

+68
-32
lines changed

tests/unitary/with_extras/langchain/test_serializers.py

Lines changed: 68 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
1-
import unittest
2-
from langchain.load.serializable import Serializable
3-
from langchain.schema.embeddings import Embeddings
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
43

5-
from langchain.vectorstores import OpenSearchVectorSearch, FAISS
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77

8-
import unittest
9-
from ads.llm.serialize import OpenSearchVectorDBSerializer, FaissSerializer, RetrievalQASerializer
10-
from tests.unitary.with_extras.langchain.test_guardrails import FakeLLM
118
import os
9+
import unittest
1210
from unittest import mock
13-
from typing import Any, Dict, List, Mapping, Optional
11+
from typing import List
12+
from langchain.load.serializable import Serializable
13+
from langchain.schema.embeddings import Embeddings
14+
from langchain.vectorstores import OpenSearchVectorSearch, FAISS
1415
from langchain.chains import RetrievalQA
1516
from langchain import llms
1617
from langchain.llms import loading
1718

18-
19+
from ads.llm.serializers.retrieval_qa import (
20+
OpenSearchVectorDBSerializer,
21+
FaissSerializer,
22+
RetrievalQASerializer,
23+
)
24+
from tests.unitary.with_extras.langchain.test_guardrails import FakeLLM
1925

2026

2127
class FakeEmbeddings(Serializable, Embeddings):
@@ -35,27 +41,38 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
3541

3642
def embed_query(self, text: str) -> List[float]:
3743
return [1] * 1024
38-
39-
44+
45+
4046
class TestOpensearchSearchVectorSerializers(unittest.TestCase):
4147
@classmethod
4248
def setUpClass(cls):
43-
cls.env_patcher = mock.patch.dict(os.environ, {"OCI_OPENSEARCH_USERNAME": "username",
44-
"OCI_OPENSEARCH_PASSWORD": "password",
45-
"OCI_OPENSEARCH_VERIFY_CERTS": "True",
46-
"OCI_OPENSEARCH_CA_CERTS": "/path/to/cert.pem"})
49+
cls.env_patcher = mock.patch.dict(
50+
os.environ,
51+
{
52+
"OCI_OPENSEARCH_USERNAME": "username",
53+
"OCI_OPENSEARCH_PASSWORD": "password",
54+
"OCI_OPENSEARCH_VERIFY_CERTS": "True",
55+
"OCI_OPENSEARCH_CA_CERTS": "/path/to/cert.pem",
56+
},
57+
)
4758
cls.env_patcher.start()
4859
cls.index_name = "test_index"
4960
cls.embeddings = FakeEmbeddings()
50-
cls.opensearch = OpenSearchVectorSearch(
51-
"https://localhost:8888",
52-
embedding_function=cls.embeddings,
53-
index_name=cls.index_name,
54-
engine="lucene",
55-
http_auth=(os.environ["OCI_OPENSEARCH_USERNAME"], os.environ["OCI_OPENSEARCH_PASSWORD"]),
56-
verify_certs=os.environ["OCI_OPENSEARCH_VERIFY_CERTS"],
57-
ca_certs=os.environ["OCI_OPENSEARCH_CA_CERTS"],
58-
)
61+
try:
62+
cls.opensearch = OpenSearchVectorSearch(
63+
"https://localhost:8888",
64+
embedding_function=cls.embeddings,
65+
index_name=cls.index_name,
66+
engine="lucene",
67+
http_auth=(
68+
os.environ["OCI_OPENSEARCH_USERNAME"],
69+
os.environ["OCI_OPENSEARCH_PASSWORD"],
70+
),
71+
verify_certs=os.environ["OCI_OPENSEARCH_VERIFY_CERTS"],
72+
ca_certs=os.environ["OCI_OPENSEARCH_CA_CERTS"],
73+
)
74+
except ImportError as ex:
75+
raise unittest.SkipTest("opensearch-py is not installed.") from ex
5976
cls.serializer = OpenSearchVectorDBSerializer()
6077
super().setUpClass()
6178

@@ -65,7 +82,12 @@ def test_type(self):
6582

6683
def test_save(self):
6784
serialized = self.serializer.save(self.opensearch)
68-
assert serialized["id"] == ['langchain', 'vectorstores', 'opensearch_vector_search', 'OpenSearchVectorSearch']
85+
assert serialized["id"] == [
86+
"langchain",
87+
"vectorstores",
88+
"opensearch_vector_search",
89+
"OpenSearchVectorSearch",
90+
]
6991
assert serialized["kwargs"]["opensearch_url"] == "https://localhost:8888"
7092
assert serialized["kwargs"]["engine"] == "lucene"
7193
assert serialized["_type"] == "OpenSearchVectorSearch"
@@ -81,7 +103,10 @@ class TestFAISSSerializers(unittest.TestCase):
81103
def setUpClass(cls):
82104
cls.embeddings = FakeEmbeddings()
83105
text_embedding_pair = [("test", [1] * 1024)]
84-
cls.db = FAISS.from_embeddings(text_embedding_pair, cls.embeddings)
106+
try:
107+
cls.db = FAISS.from_embeddings(text_embedding_pair, cls.embeddings)
108+
except ImportError as ex:
109+
raise unittest.SkipTest(ex.msg) from ex
85110
cls.serializer = FaissSerializer()
86111
super().setUpClass()
87112

@@ -90,7 +115,14 @@ def test_type(self):
90115

91116
def test_save(self):
92117
serialized = self.serializer.save(self.db)
93-
assert serialized["embedding_function"]["id"] == ["tests", "unitary", "with_extras", "langchain", "test_serializers", "FakeEmbeddings"]
118+
assert serialized["embedding_function"]["id"] == [
119+
"tests",
120+
"unitary",
121+
"with_extras",
122+
"langchain",
123+
"test_serializers",
124+
"FakeEmbeddings",
125+
]
94126
assert isinstance(serialized["vectordb"], str)
95127

96128
def test_load(self):
@@ -106,14 +138,18 @@ def setUpClass(cls):
106138
cls.llm = FakeLLM()
107139
cls.embeddings = FakeEmbeddings()
108140
text_embedding_pair = [("test", [1] * 1024)]
109-
cls.db = FAISS.from_embeddings(text_embedding_pair, cls.embeddings)
141+
try:
142+
cls.db = FAISS.from_embeddings(text_embedding_pair, cls.embeddings)
143+
except ImportError as ex:
144+
raise unittest.SkipTest(ex.msg) from ex
110145
cls.serializer = FaissSerializer()
111146
cls.retriever = cls.db.as_retriever()
112-
cls.qa = RetrievalQA.from_chain_type(llm=cls.llm,
113-
chain_type="stuff",
114-
retriever=cls.retriever)
147+
cls.qa = RetrievalQA.from_chain_type(
148+
llm=cls.llm, chain_type="stuff", retriever=cls.retriever
149+
)
115150
cls.serializer = RetrievalQASerializer()
116151
from copy import deepcopy
152+
117153
cls.original_type_to_cls_dict = deepcopy(llms.get_type_to_cls_dict())
118154
__lc_llm_dict = llms.get_type_to_cls_dict()
119155
__lc_llm_dict["custom_embedding"] = lambda: FakeEmbeddings
@@ -158,4 +194,4 @@ def tearDownClass(cls) -> None:
158194

159195

160196
if __name__ == "__main__":
161-
unittest.main()
197+
unittest.main()

0 commit comments

Comments
 (0)