@@ -1280,6 +1280,7 @@ def test_from_id_fail(self):
1280
1280
1281
1281
@patch .object (GenericModel , "from_model_catalog" )
1282
1282
def test_from_id_model_without_artifact (self , mock_from_model_catalog ):
1283
+ """Test to check model artifact is not loaded when load_artifact is set to False"""
1283
1284
test_model_id = "xxxx.datasciencemodel.xxxx"
1284
1285
mock_model = MagicMock (model_id = test_model_id , model_artifact = None )
1285
1286
mock_from_model_catalog .return_value = mock_model
@@ -1307,6 +1308,7 @@ def test_from_id_model_without_artifact(self, mock_from_model_catalog):
1307
1308
1308
1309
@patch .object (GenericModel , "from_model_catalog" )
1309
1310
def test_from_id_with_artifact (self , mock_from_model_catalog ):
1311
+ """Test to check model artifact is loaded when load_artifact is set to True"""
1310
1312
test_model_id = "xxxx.datasciencemodel.xxxx"
1311
1313
artifact_dir = "test_dir"
1312
1314
model_artifact = MagicMock (artifact_dir = artifact_dir , reload = False )
@@ -1359,6 +1361,42 @@ def test_download_artifact_fail(self):
1359
1361
generic_model = GenericModel ()
1360
1362
generic_model .download_artifact (uri = "" , auth = {"config" : "value" })
1361
1363
1364
+ @patch .object (ModelArtifact , "from_uri" )
1365
+ @patch .object (DataScienceModel , "from_id" )
1366
+ @patch .object (GenericModel , "reload_runtime_info" )
1367
+ @patch .object (DataScienceModel , "download_artifact" )
1368
+ def test_download_artifact (
1369
+ self ,
1370
+ mock_download_artifact ,
1371
+ mock_reload_runtime_info ,
1372
+ mock_dsc_model_from_id ,
1373
+ mock_from_uri ,
1374
+ ):
1375
+ """Test to check if model artifacts are updated after download_artifact is called"""
1376
+ test_model_id = "xxxx.datasciencemodel.xxxx"
1377
+ artifact_dir = "test_dir"
1378
+ self .generic_model .dsc_model = MagicMock (id = test_model_id )
1379
+ self .generic_model .model_artifact = None
1380
+ self .generic_model .artifact_dir = artifact_dir
1381
+
1382
+ mock_dsc_model_from_id .return_value = MagicMock (id = test_model_id )
1383
+ mock_download_artifact .return_value = None
1384
+ mock_artifact_instance = MagicMock (model = "test_model" )
1385
+ mock_from_uri .return_value = mock_artifact_instance
1386
+
1387
+ assert self .generic_model .model_artifact is None
1388
+
1389
+ self .generic_model .download_artifact (
1390
+ artifact_dir = artifact_dir ,
1391
+ auth = {"config" : {}},
1392
+ force_overwrite = True ,
1393
+ bucket_uri = "bucket_uri" ,
1394
+ remove_existing_artifact = True ,
1395
+ )
1396
+
1397
+ mock_reload_runtime_info .assert_called ()
1398
+ assert self .generic_model .model_artifact is not None
1399
+
1362
1400
def test_save_without_local_artifact (self ):
1363
1401
"""Test to check if model artifact is available before saving the model"""
1364
1402
self .generic_model .model_artifact = None
0 commit comments