Skip to content

Commit 855fcc5

Browse files
added a unit test
1 parent f60de23 commit 855fcc5

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,7 @@ def test_from_id_fail(self):
12801280

12811281
@patch.object(GenericModel, "from_model_catalog")
12821282
def test_from_id_model_without_artifact(self, mock_from_model_catalog):
1283+
"""Test to check model artifact is not loaded when load_artifact is set to False"""
12831284
test_model_id = "xxxx.datasciencemodel.xxxx"
12841285
mock_model = MagicMock(model_id=test_model_id, model_artifact=None)
12851286
mock_from_model_catalog.return_value = mock_model
@@ -1307,6 +1308,7 @@ def test_from_id_model_without_artifact(self, mock_from_model_catalog):
13071308

13081309
@patch.object(GenericModel, "from_model_catalog")
13091310
def test_from_id_with_artifact(self, mock_from_model_catalog):
1311+
"""Test to check model artifact is loaded when load_artifact is set to True"""
13101312
test_model_id = "xxxx.datasciencemodel.xxxx"
13111313
artifact_dir = "test_dir"
13121314
model_artifact = MagicMock(artifact_dir=artifact_dir, reload=False)
@@ -1359,6 +1361,42 @@ def test_download_artifact_fail(self):
13591361
generic_model = GenericModel()
13601362
generic_model.download_artifact(uri="", auth={"config": "value"})
13611363

1364+
@patch.object(ModelArtifact, "from_uri")
1365+
@patch.object(DataScienceModel, "from_id")
1366+
@patch.object(GenericModel, "reload_runtime_info")
1367+
@patch.object(DataScienceModel, "download_artifact")
1368+
def test_download_artifact(
1369+
self,
1370+
mock_download_artifact,
1371+
mock_reload_runtime_info,
1372+
mock_dsc_model_from_id,
1373+
mock_from_uri,
1374+
):
1375+
"""Test to check if model artifacts are updated after download_artifact is called"""
1376+
test_model_id = "xxxx.datasciencemodel.xxxx"
1377+
artifact_dir = "test_dir"
1378+
self.generic_model.dsc_model = MagicMock(id=test_model_id)
1379+
self.generic_model.model_artifact = None
1380+
self.generic_model.artifact_dir = artifact_dir
1381+
1382+
mock_dsc_model_from_id.return_value = MagicMock(id=test_model_id)
1383+
mock_download_artifact.return_value = None
1384+
mock_artifact_instance = MagicMock(model="test_model")
1385+
mock_from_uri.return_value = mock_artifact_instance
1386+
1387+
assert self.generic_model.model_artifact is None
1388+
1389+
self.generic_model.download_artifact(
1390+
artifact_dir=artifact_dir,
1391+
auth={"config": {}},
1392+
force_overwrite=True,
1393+
bucket_uri="bucket_uri",
1394+
remove_existing_artifact=True,
1395+
)
1396+
1397+
mock_reload_runtime_info.assert_called()
1398+
assert self.generic_model.model_artifact is not None
1399+
13621400
def test_save_without_local_artifact(self):
13631401
"""Test to check if model artifact is available before saving the model"""
13641402
self.generic_model.model_artifact = None

0 commit comments

Comments
 (0)