Skip to content

Commit 2ac2a3d

Browse files
committed
add more tests
1 parent 0f85ae2 commit 2ac2a3d

File tree

1 file changed

+40
-2
lines changed

1 file changed

+40
-2
lines changed

tests/unitary/with_extras/opctl/test_opctl_local_model_deployment_backend.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66

7-
from mock import ANY, patch, MagicMock, PropertyMock
7+
from mock import ANY, patch, MagicMock
88
import pytest
9-
from ads.opctl.backend.local import LocalModelDeploymentBackend, ModelCustomMetadata, os, create_signer, OCIClientFactory
9+
from ads.opctl.backend.local import LocalModelDeploymentBackend, ModelCustomMetadata, os, create_signer, OCIClientFactory, is_in_notebook_session, run_command
1010

1111

1212

@@ -155,3 +155,41 @@ def test_predict_download(
155155
mock__run_with_conda_pack.assert_called_once_with(
156156
{os.path.expanduser('~/.oci'): {'bind': '/home/datascience/.oci'}, os.path.expanduser('~/.ads_ops/models/fake_id'): {'bind': '/opt/ds/model/deployed_model/'}}, "--artifact-directory /opt/ds/model/deployed_model/ --payload 'fake_data' --auth api_key --profile DEFAULT", install=True, conda_uri='fake_path'
157157
)
158+
@patch("ads.opctl.backend.local.is_in_notebook_session", return_value=True)
159+
@patch("ads.opctl.backend.local.os.listdir", return_value=["path"])
160+
@patch("ads.opctl.backend.local.os.path.exists", return_value=False)
161+
def test_predict_download_in_notebook(
162+
self, mock_path_exists, mock_list_dir, mock_is_in_notebook
163+
):
164+
with patch("ads.opctl.backend.local._download_model") as mock__download:
165+
with patch.object(
166+
LocalModelDeploymentBackend,
167+
"_get_conda_info_from_custom_metadata",
168+
return_value=("fake_slug", "fake_path"),
169+
):
170+
with patch.object(
171+
LocalModelDeploymentBackend, "_get_conda_info_from_runtime"
172+
):
173+
with patch(
174+
"ads.opctl.backend.local.run_command"
175+
) as mock_run_command:
176+
backend = LocalModelDeploymentBackend(self.config)
177+
178+
backend.predict()
179+
mock__download.assert_called_once_with(
180+
auth='api_key',
181+
profile='DEFAULT',
182+
ocid="fake_id",
183+
artifact_directory=os.path.expanduser(
184+
"~/.ads_ops/models/fake_id"
185+
),
186+
region=None,
187+
bucket_uri="fake_bucket",
188+
timeout=None,
189+
force_overwrite=True,
190+
)
191+
dir_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../"))
192+
script_path = os.path.join(dir_path, "ads/opctl/backend/../script.py")
193+
mock_run_command.assert_called_once_with(
194+
cmd=f"python {script_path} --artifact-directory /Users/ziye/.ads_ops/models/fake_id --payload 'fake_data' --auth api_key --profile DEFAULT", shell=True
195+
)

0 commit comments

Comments
 (0)