@@ -40,39 +40,6 @@ def setUp(self) -> None:
40
40
GEN_AI_KWARGS = {"service_endpoint" : "https://endpoint.oraclecloud.com" }
41
41
ENDPOINT = "https://modeldeployment.customer-oci.com/ocid/predict"
42
42
43
- EXPECTED_LLM_CHAIN_WITH_COHERE = {
44
- "memory" : None ,
45
- "verbose" : True ,
46
- "tags" : None ,
47
- "metadata" : None ,
48
- "prompt" : {
49
- "input_variables" : ["subject" ],
50
- "input_types" : {},
51
- "output_parser" : None ,
52
- "partial_variables" : {},
53
- "template" : "Tell me a joke about {subject}" ,
54
- "template_format" : "f-string" ,
55
- "validate_template" : False ,
56
- "_type" : "prompt" ,
57
- },
58
- "llm" : {
59
- "model" : None ,
60
- "max_tokens" : 256 ,
61
- "temperature" : 0.75 ,
62
- "k" : 0 ,
63
- "p" : 1 ,
64
- "frequency_penalty" : 0.0 ,
65
- "presence_penalty" : 0.0 ,
66
- "truncate" : None ,
67
- "_type" : "cohere" ,
68
- },
69
- "output_key" : "text" ,
70
- "output_parser" : {"_type" : "default" },
71
- "return_final_only" : True ,
72
- "llm_kwargs" : {},
73
- "_type" : "llm_chain" ,
74
- }
75
-
76
43
EXPECTED_LLM_CHAIN_WITH_OCI_MD = {
77
44
"lc" : 1 ,
78
45
"type" : "constructor" ,
@@ -173,7 +140,23 @@ def test_llm_chain_serialization_with_cohere(self):
173
140
template = PromptTemplate .from_template (self .PROMPT_TEMPLATE )
174
141
llm_chain = LLMChain (prompt = template , llm = llm , verbose = True )
175
142
serialized = dump (llm_chain )
176
- self .assertEqual (serialized , self .EXPECTED_LLM_CHAIN_WITH_COHERE )
143
+
144
+ # Check the serialized chain
145
+ self .assertTrue (serialized .get ("verbose" ))
146
+ self .assertEqual (serialized .get ("_type" ), "llm_chain" )
147
+
148
+ # Check the serialized prompt template
149
+ serialized_prompt = serialized .get ("prompt" )
150
+ self .assertIsInstance (serialized_prompt , dict )
151
+ self .assertEqual (serialized_prompt .get ("_type" ), "prompt" )
152
+ self .assertEqual (set (serialized_prompt .get ("input_variables" )), {"subject" })
153
+ self .assertEqual (serialized_prompt .get ("template" ), self .PROMPT_TEMPLATE )
154
+
155
+ # Check the serialized LLM
156
+ serialized_llm = serialized .get ("llm" )
157
+ self .assertIsInstance (serialized_llm , dict )
158
+ self .assertEqual (serialized_llm .get ("_type" ), "cohere" )
159
+
177
160
llm_chain = load (serialized )
178
161
self .assertIsInstance (llm_chain , LLMChain )
179
162
self .assertIsInstance (llm_chain .prompt , PromptTemplate )
@@ -237,21 +220,37 @@ def test_runnable_sequence_serialization(self):
237
220
238
221
chain = map_input | template | llm
239
222
serialized = dump (chain )
240
- # Do not check the ID fields.
241
- expected = deepcopy (self .EXPECTED_RUNNABLE_SEQUENCE )
242
- expected ["id" ] = serialized ["id" ]
243
- expected ["kwargs" ]["first" ]["id" ] = serialized ["kwargs" ]["first" ]["id" ]
244
- expected ["kwargs" ]["first" ]["kwargs" ]["steps" ]["text" ]["id" ] = serialized [
245
- "kwargs"
246
- ]["first" ]["kwargs" ]["steps" ]["text" ]["id" ]
247
- expected ["kwargs" ]["middle" ][0 ]["id" ] = serialized ["kwargs" ]["middle" ][0 ]["id" ]
248
- self .assertEqual (serialized , expected )
223
+
224
+ self .assertEqual (serialized .get ("type" ), "constructor" )
225
+ self .assertNotIn ("_type" , serialized )
226
+
227
+ kwargs = serialized .get ("kwargs" )
228
+ self .assertIsInstance (kwargs , dict )
229
+
230
+ element_1 = kwargs .get ("first" )
231
+ self .assertEqual (element_1 .get ("_type" ), "RunnableParallel" )
232
+ step = element_1 .get ("kwargs" ).get ("steps" ).get ("text" )
233
+ self .assertEqual (step .get ("id" )[- 1 ], "RunnablePassthrough" )
234
+
235
+ element_2 = kwargs .get ("middle" )[0 ]
236
+ self .assertNotIn ("_type" , element_2 )
237
+ self .assertEqual (element_2 .get ("kwargs" ).get ("template" ), self .PROMPT_TEMPLATE )
238
+ self .assertEqual (element_2 .get ("kwargs" ).get ("input_variables" ), ["subject" ])
239
+
240
+ element_3 = kwargs .get ("last" )
241
+ self .assertNotIn ("_type" , element_3 )
242
+ self .assertEqual (element_3 .get ("id" ), ["ads" , "llm" , "ModelDeploymentTGI" ])
243
+ self .assertEqual (
244
+ element_3 .get ("kwargs" ),
245
+ {"endpoint" : "https://modeldeployment.customer-oci.com/ocid/predict" },
246
+ )
247
+
249
248
chain = load (serialized )
250
249
self .assertEqual (len (chain .steps ), 3 )
251
250
self .assertIsInstance (chain .steps [0 ], RunnableParallel )
252
251
self .assertEqual (
253
- chain .steps [0 ].dict (),
254
- { "steps" : { " text": { "input_type" : None , "func" : None , "afunc" : None }}} ,
252
+ list ( chain .steps [0 ].dict (). get ( "steps" ). keys () ),
253
+ [ " text"] ,
255
254
)
256
255
self .assertIsInstance (chain .steps [1 ], PromptTemplate )
257
256
self .assertIsInstance (chain .steps [2 ], ModelDeploymentTGI )
0 commit comments