Skip to content

Commit 2a591f7

Browse files
committed
Updated pr.
1 parent 2bbaef7 commit 2a591f7

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import os
2+
import tempfile
3+
from unittest.mock import MagicMock, patch
4+
5+
from ads.llm.deploy import ChainDeployment
6+
7+
from langchain.chains import LLMChain
8+
from langchain.prompts import PromptTemplate
9+
from tests.unitary.with_extras.langchain.test_guardrails import FakeLLM
10+
11+
class TestLangChainDeploy:
12+
13+
def generate_chain_application(self):
14+
prompt = PromptTemplate.from_template("Tell me a joke about {subject}")
15+
llm_chain = LLMChain(prompt=prompt, llm=FakeLLM(), verbose=True)
16+
return llm_chain
17+
18+
@patch("ads.model.datascience_model.DataScienceModel._to_oci_dsc_model")
19+
def test_initialize(self, mock_to_oci_dsc_model):
20+
chain_application = self.generate_chain_application()
21+
chain_deployment = ChainDeployment(chain_application, auth=MagicMock())
22+
mock_to_oci_dsc_model.assert_called()
23+
assert isinstance(chain_deployment.chain, LLMChain)
24+
25+
@patch("ads.model.runtime.env_info.get_service_packs")
26+
@patch("ads.common.auth.default_signer")
27+
@patch("ads.model.datascience_model.DataScienceModel._to_oci_dsc_model")
28+
def test_prepare(self, mock_to_oci_dsc_model, mock_default_signer, mock_get_service_packs):
29+
mock_default_signer.return_value = MagicMock()
30+
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1"
31+
inference_python_version="3.7"
32+
mock_get_service_packs.return_value = (
33+
{
34+
inference_conda_env : ("mlcpuv1", inference_python_version),
35+
},
36+
{
37+
"mlcpuv1" : (inference_conda_env, inference_python_version),
38+
}
39+
)
40+
artifact_dir = tempfile.mkdtemp()
41+
chain_application = self.generate_chain_application()
42+
chain_deployment = ChainDeployment(
43+
chain_application,
44+
artifact_dir=artifact_dir,
45+
auth=MagicMock()
46+
)
47+
48+
mock_to_oci_dsc_model.assert_called()
49+
mock_default_signer.assert_called()
50+
51+
chain_deployment.prepare(
52+
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1",
53+
inference_python_version="3.7"
54+
)
55+
56+
mock_get_service_packs.assert_called()
57+
58+
score_py_file_location = os.path.join(chain_deployment.artifact_dir, "score.py")
59+
chain_yaml_file_location = os.path.join(chain_deployment.artifact_dir, "chain.yaml")
60+
runtime_yaml_file_location = os.path.join(chain_deployment.artifact_dir, "runtime.yaml")
61+
assert os.path.isfile(score_py_file_location) and os.path.getsize(score_py_file_location) > 0
62+
assert os.path.isfile(chain_yaml_file_location) and os.path.getsize(chain_yaml_file_location) > 0
63+
assert os.path.isfile(runtime_yaml_file_location) and os.path.getsize(runtime_yaml_file_location) > 0

0 commit comments

Comments
 (0)