Skip to content

Commit 68c9ec4

Browse files
authored
Update LangChain Tests (#498)
2 parents 72509c9 + 258992e commit 68c9ec4

File tree

3 files changed

+28
-35
lines changed

3 files changed

+28
-35
lines changed

ads/jobs/templates/driver_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ def run_command(
413413
logger.log(level=level, msg=msg)
414414
# Add a small delay so that
415415
# outputs from the subsequent code will have different timestamp for oci logging
416-
time.sleep(0.05)
416+
time.sleep(0.02)
417417
if check and process.returncode != 0:
418418
# If there is an error, exit the main process with the same return code.
419419
sys.exit(process.returncode)

tests/unitary/with_extras/langchain/test_serialization.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
import os
9+
from copy import deepcopy
910
from unittest import TestCase, mock, SkipTest
1011

1112
from langchain.llms import Cohere
@@ -25,6 +26,11 @@
2526
class ChainSerializationTest(TestCase):
2627
"""Contains tests for chain serialization."""
2728

29+
# LangChain is updating frequently on the module organization,
30+
# mainly affecting the id field of the serialization.
31+
# In the test, we will not check the id field of some components.
32+
# We expect users to use the same LangChain version for serialize and de-serialize
33+
2834
def setUp(self) -> None:
2935
self.maxDiff = None
3036
return super().setUp()
@@ -75,7 +81,6 @@ def setUp(self) -> None:
7581
"prompt": {
7682
"lc": 1,
7783
"type": "constructor",
78-
"id": ["langchain_core", "prompts", "prompt", "PromptTemplate"],
7984
"kwargs": {
8085
"input_variables": ["subject"],
8186
"template": "Tell me a joke about {subject}",
@@ -118,12 +123,10 @@ def setUp(self) -> None:
118123
EXPECTED_RUNNABLE_SEQUENCE = {
119124
"lc": 1,
120125
"type": "constructor",
121-
"id": ["langchain_core", "runnables", "RunnableSequence"],
122126
"kwargs": {
123127
"first": {
124128
"lc": 1,
125129
"type": "constructor",
126-
"id": ["langchain_core", "runnables", "RunnableParallel"],
127130
"kwargs": {
128131
"steps": {
129132
"text": {
@@ -144,7 +147,6 @@ def setUp(self) -> None:
144147
{
145148
"lc": 1,
146149
"type": "constructor",
147-
"id": ["langchain_core", "prompts", "prompt", "PromptTemplate"],
148150
"kwargs": {
149151
"input_variables": ["subject"],
150152
"template": "Tell me a joke about {subject}",
@@ -185,7 +187,10 @@ def test_llm_chain_serialization_with_oci(self):
185187
template = PromptTemplate.from_template(self.PROMPT_TEMPLATE)
186188
llm_chain = LLMChain(prompt=template, llm=llm)
187189
serialized = dump(llm_chain)
188-
self.assertEqual(serialized, self.EXPECTED_LLM_CHAIN_WITH_OCI_MD)
190+
# Do not check the ID field.
191+
expected = deepcopy(self.EXPECTED_LLM_CHAIN_WITH_OCI_MD)
192+
expected["kwargs"]["prompt"]["id"] = serialized["kwargs"]["prompt"]["id"]
193+
self.assertEqual(serialized, expected)
189194
llm_chain = load(serialized)
190195
self.assertIsInstance(llm_chain, LLMChain)
191196
self.assertIsInstance(llm_chain.prompt, PromptTemplate)
@@ -202,8 +207,8 @@ def test_oci_gen_ai_serialization(self):
202207
compartment_id=self.COMPARTMENT_ID,
203208
client_kwargs=self.GEN_AI_KWARGS,
204209
)
205-
except ImportError:
206-
raise SkipTest("OCI SDK does not support Generative AI.")
210+
except ImportError as ex:
211+
raise SkipTest("OCI SDK does not support Generative AI.") from ex
207212
serialized = dump(llm)
208213
self.assertEqual(serialized, self.EXPECTED_GEN_AI_LLM)
209214
llm = load(serialized)
@@ -216,8 +221,8 @@ def test_gen_ai_embeddings_serialization(self):
216221
embeddings = GenerativeAIEmbeddings(
217222
compartment_id=self.COMPARTMENT_ID, client_kwargs=self.GEN_AI_KWARGS
218223
)
219-
except ImportError:
220-
raise SkipTest("OCI SDK does not support Generative AI.")
224+
except ImportError as ex:
225+
raise SkipTest("OCI SDK does not support Generative AI.") from ex
221226
serialized = dump(embeddings)
222227
self.assertEqual(serialized, self.EXPECTED_GEN_AI_EMBEDDINGS)
223228
embeddings = load(serialized)
@@ -232,7 +237,15 @@ def test_runnable_sequence_serialization(self):
232237

233238
chain = map_input | template | llm
234239
serialized = dump(chain)
235-
self.assertEqual(serialized, self.EXPECTED_RUNNABLE_SEQUENCE)
240+
# Do not check the ID fields.
241+
expected = deepcopy(self.EXPECTED_RUNNABLE_SEQUENCE)
242+
expected["id"] = serialized["id"]
243+
expected["kwargs"]["first"]["id"] = serialized["kwargs"]["first"]["id"]
244+
expected["kwargs"]["first"]["kwargs"]["steps"]["text"]["id"] = serialized[
245+
"kwargs"
246+
]["first"]["kwargs"]["steps"]["text"]["id"]
247+
expected["kwargs"]["middle"][0]["id"] = serialized["kwargs"]["middle"][0]["id"]
248+
self.assertEqual(serialized, expected)
236249
chain = load(serialized)
237250
self.assertEqual(len(chain.steps), 3)
238251
self.assertIsInstance(chain.steps[0], RunnableParallel)

tests/unitary/with_extras/langchain/test_serializers.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,13 @@
1313
from langchain.schema.embeddings import Embeddings
1414
from langchain.vectorstores import OpenSearchVectorSearch, FAISS
1515
from langchain.chains import RetrievalQA
16-
from langchain import llms
17-
from langchain.llms import loading
16+
from langchain.llms import Cohere
1817

1918
from ads.llm.serializers.retrieval_qa import (
2019
OpenSearchVectorDBSerializer,
2120
FaissSerializer,
2221
RetrievalQASerializer,
2322
)
24-
from tests.unitary.with_extras.langchain.test_guardrails import FakeLLM
2523

2624

2725
class FakeEmbeddings(Serializable, Embeddings):
@@ -82,8 +80,7 @@ def test_type(self):
8280

8381
def test_save(self):
8482
serialized = self.serializer.save(self.opensearch)
85-
assert serialized["id"] == [
86-
"langchain",
83+
assert serialized["id"][-3:] == [
8784
"vectorstores",
8885
"opensearch_vector_search",
8986
"OpenSearchVectorSearch",
@@ -135,7 +132,7 @@ class TestRetrievalQASerializer(unittest.TestCase):
135132
@classmethod
136133
def setUpClass(cls):
137134
# Create a sample RetrieverQA object for testing
138-
cls.llm = FakeLLM()
135+
cls.llm = Cohere(cohere_api_key="api_key")
139136
cls.embeddings = FakeEmbeddings()
140137
text_embedding_pair = [("test", [1] * 1024)]
141138
try:
@@ -148,18 +145,6 @@ def setUpClass(cls):
148145
llm=cls.llm, chain_type="stuff", retriever=cls.retriever
149146
)
150147
cls.serializer = RetrievalQASerializer()
151-
from copy import deepcopy
152-
153-
cls.original_type_to_cls_dict = deepcopy(llms.get_type_to_cls_dict())
154-
__lc_llm_dict = llms.get_type_to_cls_dict()
155-
__lc_llm_dict["custom_embedding"] = lambda: FakeEmbeddings
156-
__lc_llm_dict["custom"] = lambda: FakeLLM
157-
158-
def __new_type_to_cls_dict():
159-
return __lc_llm_dict
160-
161-
llms.get_type_to_cls_dict = __new_type_to_cls_dict
162-
loading.get_type_to_cls_dict = __new_type_to_cls_dict
163148

164149
def test_type(self):
165150
self.assertEqual(self.serializer.type(), "retrieval_qa")
@@ -176,6 +161,7 @@ def test_save(self):
176161
self.assertIn("retriever_kwargs", serialized)
177162
serialized["vectordb"]["class"] == "FAISS"
178163

164+
@mock.patch.dict(os.environ, {"COHERE_API_KEY": "api_key"})
179165
def test_load(self):
180166
# Create a sample config dictionary
181167
serialized = self.serializer.save(self.qa)
@@ -186,12 +172,6 @@ def test_load(self):
186172
# Ensure that the deserialized object is an instance of RetrieverQA
187173
self.assertIsInstance(deserialized, RetrievalQA)
188174

189-
@classmethod
190-
def tearDownClass(cls) -> None:
191-
llms.get_type_to_cls_dict = cls.original_type_to_cls_dict
192-
loading.get_type_to_cls_dict = cls.original_type_to_cls_dict
193-
return super().tearDownClass()
194-
195175

196176
if __name__ == "__main__":
197177
unittest.main()

0 commit comments

Comments
 (0)