Skip to content

Commit 7fbbd67

Browse files
committed
Update LangChain serialization tests to skip checking ID fields.
1 parent 02d940f commit 7fbbd67

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

tests/unitary/with_extras/langchain/test_serialization.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
import os
9+
from copy import deepcopy
910
from unittest import TestCase, mock, SkipTest
1011

1112
from langchain.llms import Cohere
@@ -25,6 +26,11 @@
2526
class ChainSerializationTest(TestCase):
2627
"""Contains tests for chain serialization."""
2728

29+
# LangChain is updating frequently on the module organization,
30+
# mainly affecting the id field of the serialization.
31+
# In the test, we will not check the id field of some components.
32+
# We expect users to use the same LangChain version for serialize and de-serialize
33+
2834
def setUp(self) -> None:
2935
self.maxDiff = None
3036
return super().setUp()
@@ -75,7 +81,6 @@ def setUp(self) -> None:
7581
"prompt": {
7682
"lc": 1,
7783
"type": "constructor",
78-
"id": ["langchain_core", "prompts", "prompt", "PromptTemplate"],
7984
"kwargs": {
8085
"input_variables": ["subject"],
8186
"template": "Tell me a joke about {subject}",
@@ -118,12 +123,10 @@ def setUp(self) -> None:
118123
EXPECTED_RUNNABLE_SEQUENCE = {
119124
"lc": 1,
120125
"type": "constructor",
121-
"id": ["langchain_core", "runnables", "RunnableSequence"],
122126
"kwargs": {
123127
"first": {
124128
"lc": 1,
125129
"type": "constructor",
126-
"id": ["langchain_core", "runnables", "RunnableParallel"],
127130
"kwargs": {
128131
"steps": {
129132
"text": {
@@ -144,7 +147,6 @@ def setUp(self) -> None:
144147
{
145148
"lc": 1,
146149
"type": "constructor",
147-
"id": ["langchain_core", "prompts", "prompt", "PromptTemplate"],
148150
"kwargs": {
149151
"input_variables": ["subject"],
150152
"template": "Tell me a joke about {subject}",
@@ -185,7 +187,10 @@ def test_llm_chain_serialization_with_oci(self):
185187
template = PromptTemplate.from_template(self.PROMPT_TEMPLATE)
186188
llm_chain = LLMChain(prompt=template, llm=llm)
187189
serialized = dump(llm_chain)
188-
self.assertEqual(serialized, self.EXPECTED_LLM_CHAIN_WITH_OCI_MD)
190+
# Do not check the ID field.
191+
expected = deepcopy(self.EXPECTED_LLM_CHAIN_WITH_OCI_MD)
192+
expected["kwargs"]["prompt"]["id"] = serialized["kwargs"]["prompt"]["id"]
193+
self.assertEqual(serialized, expected)
189194
llm_chain = load(serialized)
190195
self.assertIsInstance(llm_chain, LLMChain)
191196
self.assertIsInstance(llm_chain.prompt, PromptTemplate)
@@ -202,8 +207,8 @@ def test_oci_gen_ai_serialization(self):
202207
compartment_id=self.COMPARTMENT_ID,
203208
client_kwargs=self.GEN_AI_KWARGS,
204209
)
205-
except ImportError:
206-
raise SkipTest("OCI SDK does not support Generative AI.")
210+
except ImportError as ex:
211+
raise SkipTest("OCI SDK does not support Generative AI.") from ex
207212
serialized = dump(llm)
208213
self.assertEqual(serialized, self.EXPECTED_GEN_AI_LLM)
209214
llm = load(serialized)
@@ -216,8 +221,8 @@ def test_gen_ai_embeddings_serialization(self):
216221
embeddings = GenerativeAIEmbeddings(
217222
compartment_id=self.COMPARTMENT_ID, client_kwargs=self.GEN_AI_KWARGS
218223
)
219-
except ImportError:
220-
raise SkipTest("OCI SDK does not support Generative AI.")
224+
except ImportError as ex:
225+
raise SkipTest("OCI SDK does not support Generative AI.") from ex
221226
serialized = dump(embeddings)
222227
self.assertEqual(serialized, self.EXPECTED_GEN_AI_EMBEDDINGS)
223228
embeddings = load(serialized)
@@ -232,7 +237,15 @@ def test_runnable_sequence_serialization(self):
232237

233238
chain = map_input | template | llm
234239
serialized = dump(chain)
235-
self.assertEqual(serialized, self.EXPECTED_RUNNABLE_SEQUENCE)
240+
# Do not check the ID fields.
241+
expected = deepcopy(self.EXPECTED_RUNNABLE_SEQUENCE)
242+
expected["id"] = serialized["id"]
243+
expected["kwargs"]["first"]["id"] = serialized["kwargs"]["first"]["id"]
244+
expected["kwargs"]["first"]["kwargs"]["steps"]["text"]["id"] = serialized[
245+
"kwargs"
246+
]["first"]["kwargs"]["steps"]["text"]["id"]
247+
expected["kwargs"]["middle"][0]["id"] = serialized["kwargs"]["middle"][0]["id"]
248+
self.assertEqual(serialized, expected)
236249
chain = load(serialized)
237250
self.assertEqual(len(chain.steps), 3)
238251
self.assertIsInstance(chain.steps[0], RunnableParallel)

tests/unitary/with_extras/langchain/test_serializers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ def test_type(self):
8080

8181
def test_save(self):
8282
serialized = self.serializer.save(self.opensearch)
83-
assert serialized["id"] == [
84-
"langchain",
83+
assert serialized["id"][-3:] == [
8584
"vectorstores",
8685
"opensearch_vector_search",
8786
"OpenSearchVectorSearch",

0 commit comments

Comments
 (0)