
Commit 4b9b458

Update langchain serialization tests.
1 parent 55bc02a commit 4b9b458

1 file changed (+44, -45 lines)

tests/unitary/with_extras/langchain/test_serialization.py

Lines changed: 44 additions & 45 deletions
@@ -40,39 +40,6 @@ def setUp(self) -> None:
     GEN_AI_KWARGS = {"service_endpoint": "https://endpoint.oraclecloud.com"}
     ENDPOINT = "https://modeldeployment.customer-oci.com/ocid/predict"
 
-    EXPECTED_LLM_CHAIN_WITH_COHERE = {
-        "memory": None,
-        "verbose": True,
-        "tags": None,
-        "metadata": None,
-        "prompt": {
-            "input_variables": ["subject"],
-            "input_types": {},
-            "output_parser": None,
-            "partial_variables": {},
-            "template": "Tell me a joke about {subject}",
-            "template_format": "f-string",
-            "validate_template": False,
-            "_type": "prompt",
-        },
-        "llm": {
-            "model": None,
-            "max_tokens": 256,
-            "temperature": 0.75,
-            "k": 0,
-            "p": 1,
-            "frequency_penalty": 0.0,
-            "presence_penalty": 0.0,
-            "truncate": None,
-            "_type": "cohere",
-        },
-        "output_key": "text",
-        "output_parser": {"_type": "default"},
-        "return_final_only": True,
-        "llm_kwargs": {},
-        "_type": "llm_chain",
-    }
-
     EXPECTED_LLM_CHAIN_WITH_OCI_MD = {
         "lc": 1,
         "type": "constructor",
@@ -173,7 +140,23 @@ def test_llm_chain_serialization_with_cohere(self):
         template = PromptTemplate.from_template(self.PROMPT_TEMPLATE)
         llm_chain = LLMChain(prompt=template, llm=llm, verbose=True)
         serialized = dump(llm_chain)
-        self.assertEqual(serialized, self.EXPECTED_LLM_CHAIN_WITH_COHERE)
+
+        # Check the serialized chain
+        self.assertTrue(serialized.get("verbose"))
+        self.assertEqual(serialized.get("_type"), "llm_chain")
+
+        # Check the serialized prompt template
+        serialized_prompt = serialized.get("prompt")
+        self.assertIsInstance(serialized_prompt, dict)
+        self.assertEqual(serialized_prompt.get("_type"), "prompt")
+        self.assertEqual(set(serialized_prompt.get("input_variables")), {"subject"})
+        self.assertEqual(serialized_prompt.get("template"), self.PROMPT_TEMPLATE)
+
+        # Check the serialized LLM
+        serialized_llm = serialized.get("llm")
+        self.assertIsInstance(serialized_llm, dict)
+        self.assertEqual(serialized_llm.get("_type"), "cohere")
+
         llm_chain = load(serialized)
         self.assertIsInstance(llm_chain, LLMChain)
         self.assertIsInstance(llm_chain.prompt, PromptTemplate)
@@ -237,21 +220,37 @@ def test_runnable_sequence_serialization(self):
 
         chain = map_input | template | llm
         serialized = dump(chain)
-        # Do not check the ID fields.
-        expected = deepcopy(self.EXPECTED_RUNNABLE_SEQUENCE)
-        expected["id"] = serialized["id"]
-        expected["kwargs"]["first"]["id"] = serialized["kwargs"]["first"]["id"]
-        expected["kwargs"]["first"]["kwargs"]["steps"]["text"]["id"] = serialized[
-            "kwargs"
-        ]["first"]["kwargs"]["steps"]["text"]["id"]
-        expected["kwargs"]["middle"][0]["id"] = serialized["kwargs"]["middle"][0]["id"]
-        self.assertEqual(serialized, expected)
+
+        self.assertEqual(serialized.get("type"), "constructor")
+        self.assertNotIn("_type", serialized)
+
+        kwargs = serialized.get("kwargs")
+        self.assertIsInstance(kwargs, dict)
+
+        element_1 = kwargs.get("first")
+        self.assertEqual(element_1.get("_type"), "RunnableParallel")
+        step = element_1.get("kwargs").get("steps").get("text")
+        self.assertEqual(step.get("id")[-1], "RunnablePassthrough")
+
+        element_2 = kwargs.get("middle")[0]
+        self.assertNotIn("_type", element_2)
+        self.assertEqual(element_2.get("kwargs").get("template"), self.PROMPT_TEMPLATE)
+        self.assertEqual(element_2.get("kwargs").get("input_variables"), ["subject"])
+
+        element_3 = kwargs.get("last")
+        self.assertNotIn("_type", element_3)
+        self.assertEqual(element_3.get("id"), ["ads", "llm", "ModelDeploymentTGI"])
+        self.assertEqual(
+            element_3.get("kwargs"),
+            {"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict"},
+        )
+
         chain = load(serialized)
         self.assertEqual(len(chain.steps), 3)
         self.assertIsInstance(chain.steps[0], RunnableParallel)
         self.assertEqual(
-            chain.steps[0].dict(),
-            {"steps": {"text": {"input_type": None, "func": None, "afunc": None}}},
+            list(chain.steps[0].dict().get("steps").keys()),
+            ["text"],
         )
         self.assertIsInstance(chain.steps[1], PromptTemplate)
         self.assertIsInstance(chain.steps[2], ModelDeploymentTGI)
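
For readers skimming the diff, the sketch below distills the assertion style the updated tests adopt: check selected fields of the serialized dictionary and round-trip it, rather than comparing against a hard-coded expected dict that breaks whenever a default field changes upstream. It is a minimal, hypothetical example, not code from this commit; it uses langchain_core's dumpd/load as stand-ins for the dump/load helpers imported by test_serialization.py, and it serializes a bare PromptTemplate instead of a full chain.

# Minimal sketch, not the test module's own code.
# Assumptions: langchain_core's dumpd/load stand in for the dump/load helpers
# used by test_serialization.py, and a bare PromptTemplate replaces the chain.
from langchain_core.load import dumpd, load
from langchain_core.prompts import PromptTemplate

PROMPT_TEMPLATE = "Tell me a joke about {subject}"

template = PromptTemplate.from_template(PROMPT_TEMPLATE)
serialized = dumpd(template)

# Constructor-style serialization: assert only on the fields under test, so
# new default fields added upstream do not break the comparison.
assert serialized["type"] == "constructor"
assert serialized["kwargs"]["template"] == PROMPT_TEMPLATE
assert serialized["kwargs"]["input_variables"] == ["subject"]

# Round-trip and confirm the reconstructed object behaves the same.
restored = load(serialized)
assert isinstance(restored, PromptTemplate)
assert restored.format(subject="cats") == template.format(subject="cats")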
