6
6
7
7
8
8
import os
9
+ from copy import deepcopy
9
10
from unittest import TestCase , mock , SkipTest
10
11
11
12
from langchain .llms import Cohere
25
26
class ChainSerializationTest (TestCase ):
26
27
"""Contains tests for chain serialization."""
27
28
29
+ # LangChain is updating frequently on the module organization,
30
+ # mainly affecting the id field of the serialization.
31
+ # In the test, we will not check the id field of some components.
32
+ # We expect users to use the same LangChain version for serialize and de-serialize
33
+
28
34
def setUp (self ) -> None :
29
35
self .maxDiff = None
30
36
return super ().setUp ()
@@ -75,7 +81,6 @@ def setUp(self) -> None:
75
81
"prompt" : {
76
82
"lc" : 1 ,
77
83
"type" : "constructor" ,
78
- "id" : ["langchain_core" , "prompts" , "prompt" , "PromptTemplate" ],
79
84
"kwargs" : {
80
85
"input_variables" : ["subject" ],
81
86
"template" : "Tell me a joke about {subject}" ,
@@ -118,12 +123,10 @@ def setUp(self) -> None:
118
123
EXPECTED_RUNNABLE_SEQUENCE = {
119
124
"lc" : 1 ,
120
125
"type" : "constructor" ,
121
- "id" : ["langchain_core" , "runnables" , "RunnableSequence" ],
122
126
"kwargs" : {
123
127
"first" : {
124
128
"lc" : 1 ,
125
129
"type" : "constructor" ,
126
- "id" : ["langchain_core" , "runnables" , "RunnableParallel" ],
127
130
"kwargs" : {
128
131
"steps" : {
129
132
"text" : {
@@ -144,7 +147,6 @@ def setUp(self) -> None:
144
147
{
145
148
"lc" : 1 ,
146
149
"type" : "constructor" ,
147
- "id" : ["langchain_core" , "prompts" , "prompt" , "PromptTemplate" ],
148
150
"kwargs" : {
149
151
"input_variables" : ["subject" ],
150
152
"template" : "Tell me a joke about {subject}" ,
@@ -185,7 +187,10 @@ def test_llm_chain_serialization_with_oci(self):
185
187
template = PromptTemplate .from_template (self .PROMPT_TEMPLATE )
186
188
llm_chain = LLMChain (prompt = template , llm = llm )
187
189
serialized = dump (llm_chain )
188
- self .assertEqual (serialized , self .EXPECTED_LLM_CHAIN_WITH_OCI_MD )
190
+ # Do not check the ID field.
191
+ expected = deepcopy (self .EXPECTED_LLM_CHAIN_WITH_OCI_MD )
192
+ expected ["kwargs" ]["prompt" ]["id" ] = serialized ["kwargs" ]["prompt" ]["id" ]
193
+ self .assertEqual (serialized , expected )
189
194
llm_chain = load (serialized )
190
195
self .assertIsInstance (llm_chain , LLMChain )
191
196
self .assertIsInstance (llm_chain .prompt , PromptTemplate )
@@ -202,8 +207,8 @@ def test_oci_gen_ai_serialization(self):
202
207
compartment_id = self .COMPARTMENT_ID ,
203
208
client_kwargs = self .GEN_AI_KWARGS ,
204
209
)
205
- except ImportError :
206
- raise SkipTest ("OCI SDK does not support Generative AI." )
210
+ except ImportError as ex :
211
+ raise SkipTest ("OCI SDK does not support Generative AI." ) from ex
207
212
serialized = dump (llm )
208
213
self .assertEqual (serialized , self .EXPECTED_GEN_AI_LLM )
209
214
llm = load (serialized )
@@ -216,8 +221,8 @@ def test_gen_ai_embeddings_serialization(self):
216
221
embeddings = GenerativeAIEmbeddings (
217
222
compartment_id = self .COMPARTMENT_ID , client_kwargs = self .GEN_AI_KWARGS
218
223
)
219
- except ImportError :
220
- raise SkipTest ("OCI SDK does not support Generative AI." )
224
+ except ImportError as ex :
225
+ raise SkipTest ("OCI SDK does not support Generative AI." ) from ex
221
226
serialized = dump (embeddings )
222
227
self .assertEqual (serialized , self .EXPECTED_GEN_AI_EMBEDDINGS )
223
228
embeddings = load (serialized )
@@ -232,7 +237,15 @@ def test_runnable_sequence_serialization(self):
232
237
233
238
chain = map_input | template | llm
234
239
serialized = dump (chain )
235
- self .assertEqual (serialized , self .EXPECTED_RUNNABLE_SEQUENCE )
240
+ # Do not check the ID fields.
241
+ expected = deepcopy (self .EXPECTED_RUNNABLE_SEQUENCE )
242
+ expected ["id" ] = serialized ["id" ]
243
+ expected ["kwargs" ]["first" ]["id" ] = serialized ["kwargs" ]["first" ]["id" ]
244
+ expected ["kwargs" ]["first" ]["kwargs" ]["steps" ]["text" ]["id" ] = serialized [
245
+ "kwargs"
246
+ ]["first" ]["kwargs" ]["steps" ]["text" ]["id" ]
247
+ expected ["kwargs" ]["middle" ][0 ]["id" ] = serialized ["kwargs" ]["middle" ][0 ]["id" ]
248
+ self .assertEqual (serialized , expected )
236
249
chain = load (serialized )
237
250
self .assertEqual (len (chain .steps ), 3 )
238
251
self .assertIsInstance (chain .steps [0 ], RunnableParallel )
0 commit comments