Skip to content

Commit 9a387b2

Browse files
committed
fix the unit test
1 parent 33d34b1 commit 9a387b2

File tree

3 files changed

+16
-34
lines changed

3 files changed

+16
-34
lines changed

ads/opctl/model/cmds.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,14 @@ def download_model(**kwargs):
4949

5050

5151
def _download_model(
52-
ocid, artifact_directory, region, bucket_uri, timeout, force_overwrite, auth, profile
52+
ocid, artifact_directory, region, bucket_uri, timeout, force_overwrite, auth, profile=None
5353
):
5454
os.makedirs(artifact_directory, exist_ok=True)
55-
55+
kwargs = {"auth": auth}
56+
if profile:
57+
kwargs["profile"] = profile
5658
try:
57-
with AuthContext(auth=auth, profile=profile):
59+
with AuthContext(**kwargs):
5860
dsc_model = DataScienceModel.from_id(ocid)
5961
dsc_model.download_artifact(
6062
target_dir=artifact_directory,

tests/unitary/with_extras/opctl/test_opctl_local_model_deployment_backend.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -115,17 +115,7 @@ def test_predict(self, mock_path_exists, mock_list_dir):
115115
backend.predict()
116116
mock__download.assert_not_called()
117117
mock__run_with_conda_pack.assert_called_once_with(
118-
{
119-
os.path.expanduser("~/.oci"): {
120-
"bind": "/home/datascience/.oci"
121-
},
122-
os.path.expanduser("~/.ads_ops/models/fake_id"): {
123-
"bind": "/opt/ds/model/deployed_model/"
124-
},
125-
},
126-
"/opt/ds/model/deployed_model/ fake_data",
127-
install=True,
128-
conda_uri="fake_path",
118+
{os.path.expanduser('~/.oci'): {'bind': '/home/datascience/.oci'}, os.path.expanduser('~/.ads_ops/models/fake_id'): {'bind': '/opt/ds/model/deployed_model/'}}, "--artifact-directory /opt/ds/model/deployed_model/ --payload 'fake_data' --auth api_key --profile DEFAULT", install=True, conda_uri='fake_path'
129119
)
130120

131121
@patch("ads.opctl.backend.local.os.listdir", return_value=["path"])
@@ -151,7 +141,8 @@ def test_predict_download(
151141

152142
backend.predict()
153143
mock__download.assert_called_once_with(
154-
oci_auth=backend.oci_auth,
144+
auth='api_key',
145+
profile='DEFAULT',
155146
ocid="fake_id",
156147
artifact_directory=os.path.expanduser(
157148
"~/.ads_ops/models/fake_id"
@@ -162,15 +153,5 @@ def test_predict_download(
162153
force_overwrite=True,
163154
)
164155
mock__run_with_conda_pack.assert_called_once_with(
165-
{
166-
os.path.expanduser("~/.oci"): {
167-
"bind": "/home/datascience/.oci"
168-
},
169-
os.path.expanduser("~/.ads_ops/models/fake_id"): {
170-
"bind": "/opt/ds/model/deployed_model/"
171-
},
172-
},
173-
"/opt/ds/model/deployed_model/ fake_data",
174-
install=True,
175-
conda_uri="fake_path",
156+
{os.path.expanduser('~/.oci'): {'bind': '/home/datascience/.oci'}, os.path.expanduser('~/.ads_ops/models/fake_id'): {'bind': '/opt/ds/model/deployed_model/'}}, "--artifact-directory /opt/ds/model/deployed_model/ --payload 'fake_data' --auth api_key --profile DEFAULT", install=True, conda_uri='fake_path'
176157
)

tests/unitary/with_extras/opctl/test_opctl_model.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,30 @@ def test_model__download_model(mock_from_id):
1111
mock_datascience_model = MagicMock()
1212
mock_from_id.return_value = mock_datascience_model
1313
_download_model(
14-
"fake_model_id", "fake_dir", "fake_auth", "region", "bucket_uri", 36, False
14+
"fake_model_id", "fake_dir", "region", "bucket_uri", 36, False, "api_key",
1515
)
1616
mock_from_id.assert_called_with("fake_model_id")
1717
mock_datascience_model.download_artifact.assert_called_with(
1818
target_dir="fake_dir",
1919
force_overwrite=False,
2020
overwrite_existing_artifact=True,
2121
remove_existing_artifact=True,
22-
auth="fake_auth",
2322
region="region",
2423
timeout=36,
2524
bucket_uri="bucket_uri",
2625
)
2726

2827

29-
@patch.object(DataScienceModel, "from_id", side_effect=Exception("Fake error."))
28+
@patch.object(DataScienceModel, "from_id", side_effect=Exception())
3029
def test_model__download_model_error(mock_from_id):
31-
with pytest.raises(Exception, match="Fake error."):
30+
with pytest.raises(Exception):
3231
_download_model(
33-
"fake_model_id", "fake_dir", "fake_auth", "region", "bucket_uri", 36, False
32+
"fake_model_id", "fake_dir", "region", "bucket_uri", 36, False, "api_key",
3433
)
3534

3635

3736
@patch("ads.opctl.model.cmds._download_model")
38-
def test_download_model( mock__download_model):
39-
auth_mock = MagicMock()
37+
def test_download_model(mock__download_model):
4038
download_model(ocid="fake_model_id")
4139
mock__download_model.assert_called_once_with(
4240
ocid="fake_model_id",
@@ -45,5 +43,6 @@ def test_download_model( mock__download_model):
4543
bucket_uri=None,
4644
timeout=None,
4745
force_overwrite=False,
48-
oci_auth=auth_mock,
46+
auth='api_key',
47+
profile='DEFAULT'
4948
)

0 commit comments

Comments
 (0)