1
1
import os
2
- from unittest import TestCase , mock
2
+ from unittest import TestCase , mock , SkipTest
3
3
4
4
from langchain .llms import Cohere
5
5
from langchain .chains import LLMChain
6
6
from langchain .prompts import PromptTemplate
7
7
from langchain .schema .runnable import RunnableParallel , RunnablePassthrough
8
8
9
9
from ads .llm .serialize import load , dump
10
- from ads .llm import GenerativeAI , ModelDeploymentTGI , GenerativeAIEmbeddings
10
+ from ads .llm import (
11
+ GenerativeAI ,
12
+ GenerativeAIEmbeddings ,
13
+ ModelDeploymentTGI ,
14
+ ModelDeploymentVLLM ,
15
+ )
11
16
12
17
13
18
class ChainSerializationTest (TestCase ):
@@ -51,7 +56,7 @@ class ChainSerializationTest(TestCase):
51
56
"_type" : "llm_chain" ,
52
57
}
53
58
54
- EXPECTED_LLM_CHAIN_WITH_OCI_GEN_AI = {
59
+ EXPECTED_LLM_CHAIN_WITH_OCI_MD = {
55
60
"lc" : 1 ,
56
61
"type" : "constructor" ,
57
62
"id" : ["langchain" , "chains" , "llm" , "LLMChain" ],
@@ -70,12 +75,10 @@ class ChainSerializationTest(TestCase):
70
75
"llm" : {
71
76
"lc" : 1 ,
72
77
"type" : "constructor" ,
73
- "id" : ["ads" , "llm" , "GenerativeAI " ],
78
+ "id" : ["ads" , "llm" , "ModelDeploymentVLLM " ],
74
79
"kwargs" : {
75
- "compartment_id" : "<ocid>" ,
76
- "client_kwargs" : {
77
- "service_endpoint" : "https://endpoint.oraclecloud.com"
78
- },
80
+ "endpoint" : "https://modeldeployment.customer-oci.com/ocid/predict" ,
81
+ "model" : "my_model" ,
79
82
},
80
83
},
81
84
},
@@ -166,31 +169,31 @@ def test_llm_chain_serialization_with_cohere(self):
166
169
self .assertIsInstance (llm_chain .llm , Cohere )
167
170
self .assertEqual (llm_chain .input_keys , ["subject" ])
168
171
169
- def test_llm_chain_serialization_with_oci_gen_ai (self ):
172
+ def test_llm_chain_serialization_with_oci (self ):
170
173
"""Tests serialization of LLMChain with OCI Gen AI."""
171
- llm = GenerativeAI (
172
- compartment_id = self .COMPARTMENT_ID ,
173
- client_kwargs = self .GEN_AI_KWARGS ,
174
- )
174
+ llm = ModelDeploymentVLLM (endpoint = self .ENDPOINT , model = "my_model" )
175
175
template = PromptTemplate .from_template (self .PROMPT_TEMPLATE )
176
176
llm_chain = LLMChain (prompt = template , llm = llm )
177
177
serialized = dump (llm_chain )
178
- self .assertEqual (serialized , self .EXPECTED_LLM_CHAIN_WITH_OCI_GEN_AI )
178
+ self .assertEqual (serialized , self .EXPECTED_LLM_CHAIN_WITH_OCI_MD )
179
179
llm_chain = load (serialized )
180
180
self .assertIsInstance (llm_chain , LLMChain )
181
181
self .assertIsInstance (llm_chain .prompt , PromptTemplate )
182
182
self .assertEqual (llm_chain .prompt .template , self .PROMPT_TEMPLATE )
183
- self .assertIsInstance (llm_chain .llm , GenerativeAI )
184
- self .assertEqual (llm_chain .llm .compartment_id , self .COMPARTMENT_ID )
185
- self .assertEqual (llm_chain .llm .client_kwargs , self . GEN_AI_KWARGS )
183
+ self .assertIsInstance (llm_chain .llm , ModelDeploymentVLLM )
184
+ self .assertEqual (llm_chain .llm .endpoint , self .ENDPOINT )
185
+ self .assertEqual (llm_chain .llm .model , "my_model" )
186
186
self .assertEqual (llm_chain .input_keys , ["subject" ])
187
187
188
188
def test_oci_gen_ai_serialization (self ):
189
189
"""Tests serialization of OCI Gen AI LLM."""
190
- llm = GenerativeAI (
191
- compartment_id = self .COMPARTMENT_ID ,
192
- client_kwargs = self .GEN_AI_KWARGS ,
193
- )
190
+ try :
191
+ llm = GenerativeAI (
192
+ compartment_id = self .COMPARTMENT_ID ,
193
+ client_kwargs = self .GEN_AI_KWARGS ,
194
+ )
195
+ except ImportError :
196
+ raise SkipTest ("OCI SDK does not support Generative AI." )
194
197
serialized = dump (llm )
195
198
self .assertEqual (serialized , self .EXPECTED_GEN_AI_LLM )
196
199
llm = load (serialized )
@@ -199,9 +202,12 @@ def test_oci_gen_ai_serialization(self):
199
202
200
203
def test_gen_ai_embeddings_serialization (self ):
201
204
"""Tests serialization of OCI Gen AI embeddings."""
202
- embeddings = GenerativeAIEmbeddings (
203
- compartment_id = self .COMPARTMENT_ID , client_kwargs = self .GEN_AI_KWARGS
204
- )
205
+ try :
206
+ embeddings = GenerativeAIEmbeddings (
207
+ compartment_id = self .COMPARTMENT_ID , client_kwargs = self .GEN_AI_KWARGS
208
+ )
209
+ except ImportError :
210
+ raise SkipTest ("OCI SDK does not support Generative AI." )
205
211
serialized = dump (embeddings )
206
212
self .assertEqual (serialized , self .EXPECTED_GEN_AI_EMBEDDINGS )
207
213
embeddings = load (serialized )
0 commit comments