@@ -32,51 +32,14 @@ class ChainSerializationTest(TestCase):
     # We expect users to use the same LangChain version for serialize and de-serialize

     def setUp(self) -> None:
-        self.maxDiff = None
+        # self.maxDiff = None
         return super().setUp()

     PROMPT_TEMPLATE = "Tell me a joke about {subject}"
     COMPARTMENT_ID = "<ocid>"
     GEN_AI_KWARGS = {"service_endpoint": "https://endpoint.oraclecloud.com"}
     ENDPOINT = "https://modeldeployment.customer-oci.com/ocid/predict"

-    EXPECTED_LLM_CHAIN_WITH_OCI_MD = {
-        "lc": 1,
-        "type": "constructor",
-        "id": ["langchain", "chains", "llm", "LLMChain"],
-        "kwargs": {
-            "prompt": {
-                "lc": 1,
-                "type": "constructor",
-                "kwargs": {
-                    "input_variables": ["subject"],
-                    "template": "Tell me a joke about {subject}",
-                    "template_format": "f-string",
-                    "partial_variables": {},
-                },
-            },
-            "llm": {
-                "lc": 1,
-                "type": "constructor",
-                "id": ["ads", "llm", "ModelDeploymentVLLM"],
-                "kwargs": {
-                    "endpoint": "https://modeldeployment.customer-oci.com/ocid/predict",
-                    "model": "my_model",
-                },
-            },
-        },
-    }
-
-    EXPECTED_GEN_AI_LLM = {
-        "lc": 1,
-        "type": "constructor",
-        "id": ["ads", "llm", "GenerativeAI"],
-        "kwargs": {
-            "compartment_id": "<ocid>",
-            "client_kwargs": {"service_endpoint": "https://endpoint.oraclecloud.com"},
-        },
-    }
-
     EXPECTED_GEN_AI_EMBEDDINGS = {
         "lc": 1,
         "type": "constructor",
@@ -170,10 +133,6 @@ def test_llm_chain_serialization_with_oci(self):
         template = PromptTemplate.from_template(self.PROMPT_TEMPLATE)
         llm_chain = LLMChain(prompt=template, llm=llm)
         serialized = dump(llm_chain)
-        # Do not check the ID field.
-        expected = deepcopy(self.EXPECTED_LLM_CHAIN_WITH_OCI_MD)
-        expected["kwargs"]["prompt"]["id"] = serialized["kwargs"]["prompt"]["id"]
-        self.assertEqual(serialized, expected)
         llm_chain = load(serialized)
         self.assertIsInstance(llm_chain, LLMChain)
         self.assertIsInstance(llm_chain.prompt, PromptTemplate)
@@ -193,10 +152,10 @@ def test_oci_gen_ai_serialization(self):
         except ImportError as ex:
             raise SkipTest("OCI SDK does not support Generative AI.") from ex
         serialized = dump(llm)
-        self.assertEqual(serialized, self.EXPECTED_GEN_AI_LLM)
         llm = load(serialized)
         self.assertIsInstance(llm, GenerativeAI)
         self.assertEqual(llm.compartment_id, self.COMPARTMENT_ID)
+        self.assertEqual(llm.client_kwargs, self.GEN_AI_KWARGS)

     def test_gen_ai_embeddings_serialization(self):
         """Tests serialization of OCI Gen AI embeddings."""