Skip to content

Commit ac8e8fc

Browse files
committed
Update tests.
1 parent 85b6d88 commit ac8e8fc

File tree

2 files changed

+35
-19
lines changed

2 files changed

+35
-19
lines changed

tests/unitary/with_extras/langchain/test_deploy.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
from langchain.prompts import PromptTemplate
99
from tests.unitary.with_extras.langchain.test_guardrails import FakeLLM
1010

11+
1112
class TestLangChainDeploy:
12-
1313
def generate_chain_application(self):
1414
prompt = PromptTemplate.from_template("Tell me a joke about {subject}")
1515
llm_chain = LLMChain(prompt=prompt, llm=FakeLLM(), verbose=True)
1616
return llm_chain
17-
17+
1818
@patch("ads.model.datascience_model.DataScienceModel._to_oci_dsc_model")
1919
def test_initialize(self, mock_to_oci_dsc_model):
2020
chain_application = self.generate_chain_application()
@@ -25,39 +25,51 @@ def test_initialize(self, mock_to_oci_dsc_model):
2525
@patch("ads.model.runtime.env_info.get_service_packs")
2626
@patch("ads.common.auth.default_signer")
2727
@patch("ads.model.datascience_model.DataScienceModel._to_oci_dsc_model")
28-
def test_prepare(self, mock_to_oci_dsc_model, mock_default_signer, mock_get_service_packs):
28+
def test_prepare(
29+
self, mock_to_oci_dsc_model, mock_default_signer, mock_get_service_packs
30+
):
2931
mock_default_signer.return_value = MagicMock()
30-
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1"
31-
inference_python_version="3.7"
32+
inference_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1"
33+
inference_python_version = "3.7"
3234
mock_get_service_packs.return_value = (
3335
{
34-
inference_conda_env : ("mlcpuv1", inference_python_version),
36+
inference_conda_env: ("mlcpuv1", inference_python_version),
3537
},
3638
{
37-
"mlcpuv1" : (inference_conda_env, inference_python_version),
38-
}
39+
"mlcpuv1": (inference_conda_env, inference_python_version),
40+
},
3941
)
4042
artifact_dir = tempfile.mkdtemp()
4143
chain_application = self.generate_chain_application()
4244
chain_deployment = ChainDeployment(
43-
chain_application,
44-
artifact_dir=artifact_dir,
45-
auth=MagicMock()
45+
chain_application, artifact_dir=artifact_dir, auth=MagicMock()
4646
)
47-
47+
4848
mock_to_oci_dsc_model.assert_called()
49-
mock_default_signer.assert_called()
5049

5150
chain_deployment.prepare(
5251
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1",
53-
inference_python_version="3.7"
52+
inference_python_version="3.7",
5453
)
5554

5655
mock_get_service_packs.assert_called()
5756

5857
score_py_file_location = os.path.join(chain_deployment.artifact_dir, "score.py")
59-
chain_yaml_file_location = os.path.join(chain_deployment.artifact_dir, "chain.yaml")
60-
runtime_yaml_file_location = os.path.join(chain_deployment.artifact_dir, "runtime.yaml")
61-
assert os.path.isfile(score_py_file_location) and os.path.getsize(score_py_file_location) > 0
62-
assert os.path.isfile(chain_yaml_file_location) and os.path.getsize(chain_yaml_file_location) > 0
63-
assert os.path.isfile(runtime_yaml_file_location) and os.path.getsize(runtime_yaml_file_location) > 0
58+
chain_yaml_file_location = os.path.join(
59+
chain_deployment.artifact_dir, "chain.yaml"
60+
)
61+
runtime_yaml_file_location = os.path.join(
62+
chain_deployment.artifact_dir, "runtime.yaml"
63+
)
64+
assert (
65+
os.path.isfile(score_py_file_location)
66+
and os.path.getsize(score_py_file_location) > 0
67+
)
68+
assert (
69+
os.path.isfile(chain_yaml_file_location)
70+
and os.path.getsize(chain_yaml_file_location) > 0
71+
)
72+
assert (
73+
os.path.isfile(runtime_yaml_file_location)
74+
and os.path.getsize(runtime_yaml_file_location) > 0
75+
)

tests/unitary/with_extras/langchain/test_serialization.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
class ChainSerializationTest(TestCase):
2626
"""Contains tests for chain serialization."""
2727

28+
def setUp(self) -> None:
29+
self.maxDiff = None
30+
return super().setUp()
31+
2832
PROMPT_TEMPLATE = "Tell me a joke about {subject}"
2933
COMPARTMENT_ID = "<ocid>"
3034
GEN_AI_KWARGS = {"service_endpoint": "https://endpoint.oraclecloud.com"}

0 commit comments

Comments
 (0)