Skip to content

Commit a775696

Browse files
committed
Add test_serialization.py
1 parent 7b36274 commit a775696

File tree

1 file changed

+229
-0
lines changed

1 file changed

+229
-0
lines changed
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
import os
2+
from unittest import TestCase, mock
3+
4+
from langchain.llms import Cohere
5+
from langchain.chains import LLMChain
6+
from langchain.prompts import PromptTemplate
7+
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
8+
9+
from ads.llm.serialize import load, dump
10+
from ads.llm import GenerativeAI, ModelDeploymentTGI, GenerativeAIEmbeddings
11+
12+
13+
class ChainSerializationTest(TestCase):
14+
"""Contains tests for chain serialization."""
15+
16+
PROMPT_TEMPLATE = "Tell me a joke about {subject}"
17+
COMPARTMENT_ID = "<ocid>"
18+
GEN_AI_KWARGS = {"service_endpoint": "https://endpoint.oraclecloud.com"}
19+
ENDPOINT = "https://modeldeployment.customer-oci.com/ocid/predict"
20+
21+
EXPECTED_LLM_CHAIN_WITH_COHERE = {
22+
"memory": None,
23+
"verbose": True,
24+
"tags": None,
25+
"metadata": None,
26+
"prompt": {
27+
"input_variables": ["subject"],
28+
"input_types": {},
29+
"output_parser": None,
30+
"partial_variables": {},
31+
"template": "Tell me a joke about {subject}",
32+
"template_format": "f-string",
33+
"validate_template": False,
34+
"_type": "prompt",
35+
},
36+
"llm": {
37+
"model": None,
38+
"max_tokens": 256,
39+
"temperature": 0.75,
40+
"k": 0,
41+
"p": 1,
42+
"frequency_penalty": 0.0,
43+
"presence_penalty": 0.0,
44+
"truncate": None,
45+
"_type": "cohere",
46+
},
47+
"output_key": "text",
48+
"output_parser": {"_type": "default"},
49+
"return_final_only": True,
50+
"llm_kwargs": {},
51+
"_type": "llm_chain",
52+
}
53+
54+
EXPECTED_LLM_CHAIN_WITH_OCI_GEN_AI = {
55+
"lc": 1,
56+
"type": "constructor",
57+
"id": ["langchain", "chains", "llm", "LLMChain"],
58+
"kwargs": {
59+
"prompt": {
60+
"lc": 1,
61+
"type": "constructor",
62+
"id": ["langchain", "prompts", "prompt", "PromptTemplate"],
63+
"kwargs": {
64+
"input_variables": ["subject"],
65+
"template": "Tell me a joke about {subject}",
66+
"template_format": "f-string",
67+
"partial_variables": {},
68+
},
69+
},
70+
"llm": {
71+
"lc": 1,
72+
"type": "constructor",
73+
"id": ["ads", "llm", "GenerativeAI"],
74+
"kwargs": {
75+
"compartment_id": "<ocid>",
76+
"client_kwargs": {
77+
"service_endpoint": "https://endpoint.oraclecloud.com"
78+
},
79+
},
80+
},
81+
},
82+
}
83+
84+
EXPECTED_GEN_AI_LLM = {
85+
"lc": 1,
86+
"type": "constructor",
87+
"id": ["ads", "llm", "GenerativeAI"],
88+
"kwargs": {
89+
"compartment_id": "<ocid>",
90+
"client_kwargs": {"service_endpoint": "https://endpoint.oraclecloud.com"},
91+
},
92+
}
93+
94+
EXPECTED_GEN_AI_EMBEDDINGS = {
95+
"lc": 1,
96+
"type": "constructor",
97+
"id": ["ads", "llm", "GenerativeAIEmbeddings"],
98+
"kwargs": {
99+
"compartment_id": "<ocid>",
100+
"client_kwargs": {"service_endpoint": "https://endpoint.oraclecloud.com"},
101+
},
102+
}
103+
104+
EXPECTED_RUNNABLE_SEQUENCE = {
105+
"lc": 1,
106+
"type": "constructor",
107+
"id": ["langchain", "schema", "runnable", "RunnableSequence"],
108+
"kwargs": {
109+
"first": {
110+
"lc": 1,
111+
"type": "constructor",
112+
"id": ["langchain", "schema", "runnable", "RunnableParallel"],
113+
"kwargs": {
114+
"steps": {
115+
"text": {
116+
"lc": 1,
117+
"type": "constructor",
118+
"id": [
119+
"langchain",
120+
"schema",
121+
"runnable",
122+
"RunnablePassthrough",
123+
],
124+
"kwargs": {"func": None, "afunc": None, "input_type": None},
125+
}
126+
}
127+
},
128+
"_type": "RunnableParallel",
129+
},
130+
"middle": [
131+
{
132+
"lc": 1,
133+
"type": "constructor",
134+
"id": ["langchain", "prompts", "prompt", "PromptTemplate"],
135+
"kwargs": {
136+
"input_variables": ["subject"],
137+
"template": "Tell me a joke about {subject}",
138+
"template_format": "f-string",
139+
"partial_variables": {},
140+
},
141+
}
142+
],
143+
"last": {
144+
"lc": 1,
145+
"type": "constructor",
146+
"id": ["ads", "llm", "ModelDeploymentTGI"],
147+
"kwargs": {
148+
"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict"
149+
},
150+
},
151+
},
152+
}
153+
154+
@mock.patch.dict(os.environ, {"COHERE_API_KEY": "api_key"})
155+
def test_llm_chain_serialization_with_cohere(self):
156+
"""Tests serialization of LLMChain with Cohere."""
157+
llm = Cohere()
158+
template = PromptTemplate.from_template(self.PROMPT_TEMPLATE)
159+
llm_chain = LLMChain(prompt=template, llm=llm, verbose=True)
160+
serialized = dump(llm_chain)
161+
self.assertEqual(serialized, self.EXPECTED_LLM_CHAIN_WITH_COHERE)
162+
llm_chain = load(serialized)
163+
self.assertIsInstance(llm_chain, LLMChain)
164+
self.assertIsInstance(llm_chain.prompt, PromptTemplate)
165+
self.assertEqual(llm_chain.prompt.template, self.PROMPT_TEMPLATE)
166+
self.assertIsInstance(llm_chain.llm, Cohere)
167+
self.assertEqual(llm_chain.input_keys, ["subject"])
168+
169+
def test_llm_chain_serialization_with_oci_gen_ai(self):
170+
"""Tests serialization of LLMChain with OCI Gen AI."""
171+
llm = GenerativeAI(
172+
compartment_id=self.COMPARTMENT_ID,
173+
client_kwargs=self.GEN_AI_KWARGS,
174+
)
175+
template = PromptTemplate.from_template(self.PROMPT_TEMPLATE)
176+
llm_chain = LLMChain(prompt=template, llm=llm)
177+
serialized = dump(llm_chain)
178+
self.assertEqual(serialized, self.EXPECTED_LLM_CHAIN_WITH_OCI_GEN_AI)
179+
llm_chain = load(serialized)
180+
self.assertIsInstance(llm_chain, LLMChain)
181+
self.assertIsInstance(llm_chain.prompt, PromptTemplate)
182+
self.assertEqual(llm_chain.prompt.template, self.PROMPT_TEMPLATE)
183+
self.assertIsInstance(llm_chain.llm, GenerativeAI)
184+
self.assertEqual(llm_chain.llm.compartment_id, self.COMPARTMENT_ID)
185+
self.assertEqual(llm_chain.llm.client_kwargs, self.GEN_AI_KWARGS)
186+
self.assertEqual(llm_chain.input_keys, ["subject"])
187+
188+
def test_oci_gen_ai_serialization(self):
189+
"""Tests serialization of OCI Gen AI LLM."""
190+
llm = GenerativeAI(
191+
compartment_id=self.COMPARTMENT_ID,
192+
client_kwargs=self.GEN_AI_KWARGS,
193+
)
194+
serialized = dump(llm)
195+
self.assertEqual(serialized, self.EXPECTED_GEN_AI_LLM)
196+
llm = load(serialized)
197+
self.assertIsInstance(llm, GenerativeAI)
198+
self.assertEqual(llm.compartment_id, self.COMPARTMENT_ID)
199+
200+
def test_gen_ai_embeddings_serialization(self):
201+
"""Tests serialization of OCI Gen AI embeddings."""
202+
embeddings = GenerativeAIEmbeddings(
203+
compartment_id=self.COMPARTMENT_ID, client_kwargs=self.GEN_AI_KWARGS
204+
)
205+
serialized = dump(embeddings)
206+
self.assertEqual(serialized, self.EXPECTED_GEN_AI_EMBEDDINGS)
207+
embeddings = load(serialized)
208+
self.assertIsInstance(embeddings, GenerativeAIEmbeddings)
209+
self.assertEqual(embeddings.compartment_id, self.COMPARTMENT_ID)
210+
211+
def test_runnable_sequence_serialization(self):
212+
"""Tests serialization of runnable sequence."""
213+
map_input = RunnableParallel(text=RunnablePassthrough())
214+
template = PromptTemplate.from_template(self.PROMPT_TEMPLATE)
215+
llm = ModelDeploymentTGI(endpoint=self.ENDPOINT)
216+
217+
chain = map_input | template | llm
218+
serialized = dump(chain)
219+
self.assertEqual(serialized, self.EXPECTED_RUNNABLE_SEQUENCE)
220+
chain = load(serialized)
221+
self.assertEqual(len(chain.steps), 3)
222+
self.assertIsInstance(chain.steps[0], RunnableParallel)
223+
self.assertEqual(
224+
chain.steps[0].dict(),
225+
{"steps": {"text": {"input_type": None, "func": None, "afunc": None}}},
226+
)
227+
self.assertIsInstance(chain.steps[1], PromptTemplate)
228+
self.assertIsInstance(chain.steps[2], ModelDeploymentTGI)
229+
self.assertEqual(chain.steps[2].endpoint, self.ENDPOINT)

0 commit comments

Comments
 (0)