|
| 1 | +import os |
| 2 | +import tempfile |
| 3 | +from unittest.mock import MagicMock, patch |
| 4 | + |
| 5 | +from ads.llm.deploy import ChainDeployment |
| 6 | + |
| 7 | +from langchain.chains import LLMChain |
| 8 | +from langchain.prompts import PromptTemplate |
| 9 | +from tests.unitary.with_extras.langchain.test_guardrails import FakeLLM |
| 10 | + |
| 11 | +class TestLangChainDeploy: |
| 12 | + |
| 13 | + def generate_chain_application(self): |
| 14 | + prompt = PromptTemplate.from_template("Tell me a joke about {subject}") |
| 15 | + llm_chain = LLMChain(prompt=prompt, llm=FakeLLM(), verbose=True) |
| 16 | + return llm_chain |
| 17 | + |
| 18 | + @patch("ads.model.datascience_model.DataScienceModel._to_oci_dsc_model") |
| 19 | + def test_initialize(self, mock_to_oci_dsc_model): |
| 20 | + chain_application = self.generate_chain_application() |
| 21 | + chain_deployment = ChainDeployment(chain_application, auth=MagicMock()) |
| 22 | + mock_to_oci_dsc_model.assert_called() |
| 23 | + assert isinstance(chain_deployment.chain, LLMChain) |
| 24 | + |
| 25 | + @patch("ads.model.runtime.env_info.get_service_packs") |
| 26 | + @patch("ads.common.auth.default_signer") |
| 27 | + @patch("ads.model.datascience_model.DataScienceModel._to_oci_dsc_model") |
| 28 | + def test_prepare(self, mock_to_oci_dsc_model, mock_default_signer, mock_get_service_packs): |
| 29 | + mock_default_signer.return_value = MagicMock() |
| 30 | + inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1" |
| 31 | + inference_python_version="3.7" |
| 32 | + mock_get_service_packs.return_value = ( |
| 33 | + { |
| 34 | + inference_conda_env : ("mlcpuv1", inference_python_version), |
| 35 | + }, |
| 36 | + { |
| 37 | + "mlcpuv1" : (inference_conda_env, inference_python_version), |
| 38 | + } |
| 39 | + ) |
| 40 | + artifact_dir = tempfile.mkdtemp() |
| 41 | + chain_application = self.generate_chain_application() |
| 42 | + chain_deployment = ChainDeployment( |
| 43 | + chain_application, |
| 44 | + artifact_dir=artifact_dir, |
| 45 | + auth=MagicMock() |
| 46 | + ) |
| 47 | + |
| 48 | + mock_to_oci_dsc_model.assert_called() |
| 49 | + mock_default_signer.assert_called() |
| 50 | + |
| 51 | + chain_deployment.prepare( |
| 52 | + inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1", |
| 53 | + inference_python_version="3.7" |
| 54 | + ) |
| 55 | + |
| 56 | + mock_get_service_packs.assert_called() |
| 57 | + |
| 58 | + score_py_file_location = os.path.join(chain_deployment.artifact_dir, "score.py") |
| 59 | + chain_yaml_file_location = os.path.join(chain_deployment.artifact_dir, "chain.yaml") |
| 60 | + runtime_yaml_file_location = os.path.join(chain_deployment.artifact_dir, "runtime.yaml") |
| 61 | + assert os.path.isfile(score_py_file_location) and os.path.getsize(score_py_file_location) > 0 |
| 62 | + assert os.path.isfile(chain_yaml_file_location) and os.path.getsize(chain_yaml_file_location) > 0 |
| 63 | + assert os.path.isfile(runtime_yaml_file_location) and os.path.getsize(runtime_yaml_file_location) > 0 |
0 commit comments