|
40 | 40 | _ATTRIBUTES_TO_SHOW_,
|
41 | 41 | GenericModel,
|
42 | 42 | NotActiveDeploymentError,
|
| 43 | + ArtifactsNotAvailableError, |
43 | 44 | SummaryStatus,
|
44 | 45 | _prepare_artifact_dir,
|
45 | 46 | )
|
@@ -1277,6 +1278,97 @@ def test_from_id_fail(self):
|
1277 | 1278 | ignore_conda_error=True,
|
1278 | 1279 | )
|
1279 | 1280 |
|
| 1281 | + @patch.object(GenericModel, "from_model_catalog") |
| 1282 | + def test_from_id_model_without_artifact(self, mock_from_model_catalog): |
| 1283 | + test_model_id = "xxxx.datasciencemodel.xxxx" |
| 1284 | + mock_model = MagicMock(model_id=test_model_id, model_artifact=None) |
| 1285 | + mock_from_model_catalog.return_value = mock_model |
| 1286 | + |
| 1287 | + test_model_result = GenericModel.from_id( |
| 1288 | + ocid=test_model_id, |
| 1289 | + compartment_id="test_compartment_id", |
| 1290 | + load_artifact=False, |
| 1291 | + ) |
| 1292 | + mock_from_model_catalog.assert_called_with( |
| 1293 | + test_model_id, |
| 1294 | + model_file_name=None, |
| 1295 | + artifact_dir=None, |
| 1296 | + auth=None, |
| 1297 | + force_overwrite=False, |
| 1298 | + properties=None, |
| 1299 | + bucket_uri=None, |
| 1300 | + remove_existing_artifact=True, |
| 1301 | + compartment_id="test_compartment_id", |
| 1302 | + ignore_conda_error=False, |
| 1303 | + load_artifact=False, |
| 1304 | + ) |
| 1305 | + assert test_model_result.model_artifact is None |
| 1306 | + assert test_model_result == mock_model |
| 1307 | + |
| 1308 | + @patch.object(GenericModel, "from_model_catalog") |
| 1309 | + def test_from_id_with_artifact(self, mock_from_model_catalog): |
| 1310 | + test_model_id = "xxxx.datasciencemodel.xxxx" |
| 1311 | + artifact_dir = "test_dir" |
| 1312 | + model_artifact = MagicMock(artifact_dir=artifact_dir, reload=False) |
| 1313 | + mock_model = MagicMock(model_id=test_model_id, model_artifact=model_artifact) |
| 1314 | + mock_from_model_catalog.return_value = mock_model |
| 1315 | + |
| 1316 | + test_model_result = GenericModel.from_id( |
| 1317 | + artifact_dir=artifact_dir, |
| 1318 | + ocid=test_model_id, |
| 1319 | + compartment_id="test_compartment_id", |
| 1320 | + remove_existing_artifact=True, |
| 1321 | + load_artifact=True, |
| 1322 | + ) |
| 1323 | + |
| 1324 | + mock_from_model_catalog.assert_called_with( |
| 1325 | + test_model_id, |
| 1326 | + model_file_name=None, |
| 1327 | + artifact_dir=artifact_dir, |
| 1328 | + auth=None, |
| 1329 | + force_overwrite=False, |
| 1330 | + properties=None, |
| 1331 | + bucket_uri=None, |
| 1332 | + remove_existing_artifact=True, |
| 1333 | + compartment_id="test_compartment_id", |
| 1334 | + ignore_conda_error=False, |
| 1335 | + load_artifact=True, |
| 1336 | + ) |
| 1337 | + assert test_model_result.model_artifact is not None |
| 1338 | + assert test_model_result == mock_model |
| 1339 | + |
| 1340 | + @patch("ads.common.auth.default_signer") |
| 1341 | + def test_verify_without_local_artifact(self, mock_signer): |
| 1342 | + """Test verify input data with model without artifacts loaded.""" |
| 1343 | + _prepare(self.generic_model) |
| 1344 | + self.generic_model.model_artifact = None |
| 1345 | + |
| 1346 | + with pytest.raises( |
| 1347 | + ArtifactsNotAvailableError, |
| 1348 | + match="Model artifacts are either not generated or " |
| 1349 | + "not available locally.", |
| 1350 | + ): |
| 1351 | + self.generic_model.verify(self.X_test.tolist()) |
| 1352 | + |
| 1353 | + def test_download_artifact_fail(self): |
| 1354 | + """Test to check if model is loaded first before downloading artifacts""" |
| 1355 | + with pytest.raises( |
| 1356 | + ValueError, |
| 1357 | + match="`model_id` is not available, load the GenericModel object first.", |
| 1358 | + ): |
| 1359 | + generic_model = GenericModel() |
| 1360 | + generic_model.download_artifact(uri="", auth={"config": "value"}) |
| 1361 | + |
| 1362 | + def test_save_without_local_artifact(self): |
| 1363 | + """Test to check if model artifact is available before saving the model""" |
| 1364 | + self.generic_model.model_artifact = None |
| 1365 | + with pytest.raises( |
| 1366 | + ArtifactsNotAvailableError, |
| 1367 | + match="Model artifacts are either not generated or " |
| 1368 | + "not available locally.", |
| 1369 | + ): |
| 1370 | + self.generic_model.save() |
| 1371 | + |
1280 | 1372 | @patch.object(ModelDeployment, "activate")
|
1281 | 1373 | @patch.object(ModelDeployment, "deactivate")
|
1282 | 1374 | @patch("ads.common.auth.default_signer")
|
|
0 commit comments