Skip to content

Commit c9948d4

Browse files
committed
Update test_serialization.py to skip test if Gen AI not available.
1 parent 4ff91a4 commit c9948d4

File tree

1 file changed

+30
-24
lines changed

1 file changed

+30
-24
lines changed

tests/unitary/with_extras/langchain/test_serialization.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import os
2-
from unittest import TestCase, mock
2+
from unittest import TestCase, mock, SkipTest
33

44
from langchain.llms import Cohere
55
from langchain.chains import LLMChain
66
from langchain.prompts import PromptTemplate
77
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
88

99
from ads.llm.serialize import load, dump
10-
from ads.llm import GenerativeAI, ModelDeploymentTGI, GenerativeAIEmbeddings
10+
from ads.llm import (
11+
GenerativeAI,
12+
GenerativeAIEmbeddings,
13+
ModelDeploymentTGI,
14+
ModelDeploymentVLLM,
15+
)
1116

1217

1318
class ChainSerializationTest(TestCase):
@@ -51,7 +56,7 @@ class ChainSerializationTest(TestCase):
5156
"_type": "llm_chain",
5257
}
5358

54-
EXPECTED_LLM_CHAIN_WITH_OCI_GEN_AI = {
59+
EXPECTED_LLM_CHAIN_WITH_OCI_MD = {
5560
"lc": 1,
5661
"type": "constructor",
5762
"id": ["langchain", "chains", "llm", "LLMChain"],
@@ -70,12 +75,10 @@ class ChainSerializationTest(TestCase):
7075
"llm": {
7176
"lc": 1,
7277
"type": "constructor",
73-
"id": ["ads", "llm", "GenerativeAI"],
78+
"id": ["ads", "llm", "ModelDeploymentVLLM"],
7479
"kwargs": {
75-
"compartment_id": "<ocid>",
76-
"client_kwargs": {
77-
"service_endpoint": "https://endpoint.oraclecloud.com"
78-
},
80+
"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict",
81+
"model": "my_model",
7982
},
8083
},
8184
},
@@ -166,31 +169,31 @@ def test_llm_chain_serialization_with_cohere(self):
166169
self.assertIsInstance(llm_chain.llm, Cohere)
167170
self.assertEqual(llm_chain.input_keys, ["subject"])
168171

169-
def test_llm_chain_serialization_with_oci_gen_ai(self):
172+
def test_llm_chain_serialization_with_oci(self):
170173
"""Tests serialization of LLMChain with OCI Gen AI."""
171-
llm = GenerativeAI(
172-
compartment_id=self.COMPARTMENT_ID,
173-
client_kwargs=self.GEN_AI_KWARGS,
174-
)
174+
llm = ModelDeploymentVLLM(endpoint=self.ENDPOINT, model="my_model")
175175
template = PromptTemplate.from_template(self.PROMPT_TEMPLATE)
176176
llm_chain = LLMChain(prompt=template, llm=llm)
177177
serialized = dump(llm_chain)
178-
self.assertEqual(serialized, self.EXPECTED_LLM_CHAIN_WITH_OCI_GEN_AI)
178+
self.assertEqual(serialized, self.EXPECTED_LLM_CHAIN_WITH_OCI_MD)
179179
llm_chain = load(serialized)
180180
self.assertIsInstance(llm_chain, LLMChain)
181181
self.assertIsInstance(llm_chain.prompt, PromptTemplate)
182182
self.assertEqual(llm_chain.prompt.template, self.PROMPT_TEMPLATE)
183-
self.assertIsInstance(llm_chain.llm, GenerativeAI)
184-
self.assertEqual(llm_chain.llm.compartment_id, self.COMPARTMENT_ID)
185-
self.assertEqual(llm_chain.llm.client_kwargs, self.GEN_AI_KWARGS)
183+
self.assertIsInstance(llm_chain.llm, ModelDeploymentVLLM)
184+
self.assertEqual(llm_chain.llm.endpoint, self.ENDPOINT)
185+
self.assertEqual(llm_chain.llm.model, "my_model")
186186
self.assertEqual(llm_chain.input_keys, ["subject"])
187187

188188
def test_oci_gen_ai_serialization(self):
189189
"""Tests serialization of OCI Gen AI LLM."""
190-
llm = GenerativeAI(
191-
compartment_id=self.COMPARTMENT_ID,
192-
client_kwargs=self.GEN_AI_KWARGS,
193-
)
190+
try:
191+
llm = GenerativeAI(
192+
compartment_id=self.COMPARTMENT_ID,
193+
client_kwargs=self.GEN_AI_KWARGS,
194+
)
195+
except ImportError:
196+
raise SkipTest("OCI SDK does not support Generative AI.")
194197
serialized = dump(llm)
195198
self.assertEqual(serialized, self.EXPECTED_GEN_AI_LLM)
196199
llm = load(serialized)
@@ -199,9 +202,12 @@ def test_oci_gen_ai_serialization(self):
199202

200203
def test_gen_ai_embeddings_serialization(self):
201204
"""Tests serialization of OCI Gen AI embeddings."""
202-
embeddings = GenerativeAIEmbeddings(
203-
compartment_id=self.COMPARTMENT_ID, client_kwargs=self.GEN_AI_KWARGS
204-
)
205+
try:
206+
embeddings = GenerativeAIEmbeddings(
207+
compartment_id=self.COMPARTMENT_ID, client_kwargs=self.GEN_AI_KWARGS
208+
)
209+
except ImportError:
210+
raise SkipTest("OCI SDK does not support Generative AI.")
205211
serialized = dump(embeddings)
206212
self.assertEqual(serialized, self.EXPECTED_GEN_AI_EMBEDDINGS)
207213
embeddings = load(serialized)

0 commit comments

Comments
 (0)