|
4 | 4 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5 | 5 |
|
6 | 6 |
|
7 |
| -from mock import ANY, patch, MagicMock, PropertyMock |
| 7 | +from mock import ANY, patch, MagicMock |
8 | 8 | import pytest
|
9 |
| -from ads.opctl.backend.local import LocalModelDeploymentBackend, ModelCustomMetadata, os, create_signer, OCIClientFactory |
| 9 | +from ads.opctl.backend.local import LocalModelDeploymentBackend, ModelCustomMetadata, os, create_signer, OCIClientFactory, is_in_notebook_session, run_command |
10 | 10 |
|
11 | 11 |
|
12 | 12 |
|
@@ -155,3 +155,41 @@ def test_predict_download(
|
155 | 155 | mock__run_with_conda_pack.assert_called_once_with(
|
156 | 156 | {os.path.expanduser('~/.oci'): {'bind': '/home/datascience/.oci'}, os.path.expanduser('~/.ads_ops/models/fake_id'): {'bind': '/opt/ds/model/deployed_model/'}}, "--artifact-directory /opt/ds/model/deployed_model/ --payload 'fake_data' --auth api_key --profile DEFAULT", install=True, conda_uri='fake_path'
|
157 | 157 | )
|
| 158 | + @patch("ads.opctl.backend.local.is_in_notebook_session", return_value=True) |
| 159 | + @patch("ads.opctl.backend.local.os.listdir", return_value=["path"]) |
| 160 | + @patch("ads.opctl.backend.local.os.path.exists", return_value=False) |
| 161 | + def test_predict_download_in_notebook( |
| 162 | + self, mock_path_exists, mock_list_dir, mock_is_in_notebook |
| 163 | + ): |
| 164 | + with patch("ads.opctl.backend.local._download_model") as mock__download: |
| 165 | + with patch.object( |
| 166 | + LocalModelDeploymentBackend, |
| 167 | + "_get_conda_info_from_custom_metadata", |
| 168 | + return_value=("fake_slug", "fake_path"), |
| 169 | + ): |
| 170 | + with patch.object( |
| 171 | + LocalModelDeploymentBackend, "_get_conda_info_from_runtime" |
| 172 | + ): |
| 173 | + with patch( |
| 174 | + "ads.opctl.backend.local.run_command" |
| 175 | + ) as mock_run_command: |
| 176 | + backend = LocalModelDeploymentBackend(self.config) |
| 177 | + |
| 178 | + backend.predict() |
| 179 | + mock__download.assert_called_once_with( |
| 180 | + auth='api_key', |
| 181 | + profile='DEFAULT', |
| 182 | + ocid="fake_id", |
| 183 | + artifact_directory=os.path.expanduser( |
| 184 | + "~/.ads_ops/models/fake_id" |
| 185 | + ), |
| 186 | + region=None, |
| 187 | + bucket_uri="fake_bucket", |
| 188 | + timeout=None, |
| 189 | + force_overwrite=True, |
| 190 | + ) |
| 191 | + dir_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../")) |
| 192 | + script_path = os.path.join(dir_path, "ads/opctl/backend/../script.py") |
| 193 | + mock_run_command.assert_called_once_with( |
| 194 | + cmd=f"python {script_path} --artifact-directory /Users/ziye/.ads_ops/models/fake_id --payload 'fake_data' --auth api_key --profile DEFAULT", shell=True |
| 195 | + ) |
0 commit comments