Skip to content

Commit 161b6f3

Browse files
authored
Merge branch 'main' into release/v2.13.0
2 parents fb67b7e + d50c68e commit 161b6f3

File tree

9 files changed

+259
-157
lines changed

9 files changed

+259
-157
lines changed

ads/model/artifact_downloader.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

43
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -12,9 +11,9 @@
1211
from zipfile import ZipFile
1312

1413
from ads.common import utils
14+
from ads.common.object_storage_details import ObjectStorageDetails
1515
from ads.common.utils import extract_region
1616
from ads.model.service.oci_datascience_model import OCIDataScienceModel
17-
from ads.common.object_storage_details import ObjectStorageDetails
1817

1918

2019
class ArtifactDownloader(ABC):
@@ -169,9 +168,9 @@ def __init__(
169168

170169
def _download(self):
171170
"""Downloads model artifacts."""
172-
self.progress.update(f"Importing model artifacts from catalog")
171+
self.progress.update("Importing model artifacts from catalog")
173172

174-
if self.dsc_model.is_model_by_reference() and self.model_file_description:
173+
if self.dsc_model._is_model_by_reference() and self.model_file_description:
175174
self.download_from_model_file_description()
176175
self.progress.update()
177176
return

ads/model/datascience_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1778,7 +1778,7 @@ def _update_from_oci_dsc_model(
17781778
artifact_info = self.dsc_model.get_artifact_info()
17791779
_, file_name_info = cgi.parse_header(artifact_info["Content-Disposition"])
17801780

1781-
if self.dsc_model.is_model_by_reference():
1781+
if self.dsc_model._is_model_by_reference():
17821782
_, file_extension = os.path.splitext(file_name_info["filename"])
17831783
if file_extension.lower() == ".json":
17841784
bucket_uri, _ = self._download_file_description_artifact()

ads/model/model_metadata.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,7 +1509,10 @@ def _from_oci_metadata(cls, metadata_list):
15091509
metadata = cls()
15101510
for oci_item in metadata_list:
15111511
item = ModelTaxonomyMetadataItem._from_oci_metadata(oci_item)
1512-
metadata[item.key].update(value=item.value)
1512+
if item.key in metadata.keys:
1513+
metadata[item.key].update(value=item.value)
1514+
else:
1515+
metadata._items.add(item)
15131516
return metadata
15141517

15151518
def to_dataframe(self) -> pd.DataFrame:
@@ -1562,7 +1565,10 @@ def from_dict(cls, data: Dict) -> "ModelTaxonomyMetadata":
15621565
metadata = cls()
15631566
for item in data["data"]:
15641567
item = ModelTaxonomyMetadataItem.from_dict(item)
1565-
metadata[item.key].update(value=item.value)
1568+
if item.key in metadata.keys:
1569+
metadata[item.key].update(value=item.value)
1570+
else:
1571+
metadata._items.add(item)
15661572
return metadata
15671573

15681574

ads/model/service/oci_datascience_model.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,32 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

43
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import logging
8-
import time
97
from functools import wraps
108
from io import BytesIO
119
from typing import Callable, Dict, List, Optional
1210

1311
import oci.data_science
14-
from ads.common import utils
15-
from ads.common.object_storage_details import ObjectStorageDetails
16-
from ads.common.oci_datascience import OCIDataScienceMixin
17-
from ads.common.oci_mixin import OCIWorkRequestMixin
18-
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
19-
from ads.common.utils import extract_region
20-
from ads.common.work_request import DataScienceWorkRequest
21-
from ads.model.deployment import ModelDeployment
2212
from oci.data_science.models import (
2313
ArtifactExportDetailsObjectStorage,
2414
ArtifactImportDetailsObjectStorage,
2515
CreateModelDetails,
2616
ExportModelArtifactDetails,
2717
ImportModelArtifactDetails,
2818
UpdateModelDetails,
29-
WorkRequest,
3019
)
3120
from oci.exceptions import ServiceError
3221

22+
from ads.common.object_storage_details import ObjectStorageDetails
23+
from ads.common.oci_datascience import OCIDataScienceMixin
24+
from ads.common.oci_mixin import OCIWorkRequestMixin
25+
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
26+
from ads.common.utils import extract_region
27+
from ads.common.work_request import DataScienceWorkRequest
28+
from ads.model.deployment import ModelDeployment
29+
3330
logger = logging.getLogger(__name__)
3431

3532
_REQUEST_INTERVAL_IN_SEC = 3
@@ -282,7 +279,7 @@ def get_artifact_info(self) -> Dict:
282279
msg="Model needs to be restored before the archived artifact content can be accessed."
283280
)
284281
def restore_archived_model_artifact(
285-
self, restore_model_for_hours_specified: Optional[int] = None
282+
self, restore_model_for_hours_specified: Optional[int] = None
286283
) -> None:
287284
"""Restores the archived model artifact.
288285
@@ -304,7 +301,8 @@ def restore_archived_model_artifact(
304301
"""
305302
return self.client.restore_archived_model_artifact(
306303
model_id=self.id,
307-
restore_model_for_hours_specified=restore_model_for_hours_specified).headers["opc-work-request-id"]
304+
restore_model_for_hours_specified=restore_model_for_hours_specified,
305+
).headers["opc-work-request-id"]
308306

309307
@check_for_model_id(
310308
msg="Model needs to be saved to the Model Catalog before the artifact content can be read."
@@ -581,7 +579,7 @@ def from_id(cls, ocid: str) -> "OCIDataScienceModel":
581579
raise ValueError("Model OCID not provided.")
582580
return super().from_ocid(ocid)
583581

584-
def is_model_by_reference(self):
582+
def _is_model_by_reference(self):
585583
"""Checks if model is created by reference
586584
Returns
587585
-------

tests/unitary/default_setup/model/test_artifact_downloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def test_download_large_artifact_from_model_file_description(self, mock_download
189189
"""Tests whether model_file_description is loaded within downloader and is parsed, and also if
190190
# download_from_model_file_description is appropriately called."""
191191

192-
self.mock_dsc_model.is_model_by_reference.return_value = True
192+
self.mock_dsc_model._is_model_by_reference.return_value = True
193193
self.mock_artifact_file_path = os.path.join(
194194
self.curr_dir, "test_files/model_description.json"
195195
)

tests/unitary/default_setup/model/test_datascience_model.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
ModelArtifactSizeError,
2727
BucketNotVersionedError,
2828
ModelFileDescriptionError,
29-
InvalidArtifactType, ModelRetentionSetting, ModelBackupSetting,
29+
InvalidArtifactType,
30+
ModelRetentionSetting,
31+
ModelBackupSetting,
3032
)
3133
from ads.model.model_metadata import (
3234
ModelCustomMetadata,
@@ -44,7 +46,7 @@
4446
from ads.config import AQUA_SERVICE_MODELS_BUCKET as SERVICE_MODELS_BUCKET
4547

4648
MODEL_OCID = "ocid1.datasciencemodel.oc1.iad.<unique_ocid>"
47-
49+
4850
OCI_MODEL_PAYLOAD = {
4951
"id": MODEL_OCID,
5052
"compartment_id": "ocid1.compartment.oc1..<unique_ocid>",
@@ -71,16 +73,17 @@
7173
{"key": "UseCaseType", "value": "multinomial_classification"},
7274
{"key": "Hyperparameters"},
7375
{"key": "ArtifactTestResults"},
76+
{"key": "UnexpectedKey", "value": "unexpected_value"},
7477
],
7578
"backup_setting": {
7679
"is_backup_enabled": True,
7780
"backup_region": "us-phoenix-1",
78-
"customer_notification_type": "ALL"
81+
"customer_notification_type": "ALL",
7982
},
8083
"retention_setting": {
8184
"archive_after_days": 30,
8285
"delete_after_days": 90,
83-
"customer_notification_type": "ALL"
86+
"customer_notification_type": "ALL",
8487
},
8588
"input_schema": '{"schema": [{"dtype": "int64", "feature_type": "Integer", "name": 0, "domain": {"values": "", "stats": {}, "constraints": []}, "required": true, "description": "0", "order": 0}], "version": "1.1"}',
8689
"output_schema": '{"schema": [{"dtype": "int64", "feature_type": "Integer", "name": 0, "domain": {"values": "", "stats": {}, "constraints": []}, "required": true, "description": "0", "order": 0}], "version": "1.1"}',
@@ -148,6 +151,7 @@
148151
{"key": "UseCaseType", "value": "multinomial_classification"},
149152
{"key": "Hyperparameters", "value": None},
150153
{"key": "ArtifactTestResults", "value": None},
154+
{"key": "UnexpectedKey", "value": "unexpected_value"},
151155
]
152156
},
153157
"provenanceMetadata": {
@@ -161,12 +165,12 @@
161165
"backupSetting": {
162166
"is_backup_enabled": True,
163167
"backup_region": "us-phoenix-1",
164-
"customer_notification_type": "ALL"
168+
"customer_notification_type": "ALL",
165169
},
166170
"retentionSetting": {
167171
"archive_after_days": 30,
168172
"delete_after_days": 90,
169-
"customer_notification_type": "ALL"
173+
"customer_notification_type": "ALL",
170174
},
171175
"artifact": "ocid1.datasciencemodel.oc1.iad.<unique_ocid>.zip",
172176
}
@@ -327,8 +331,8 @@ def test_with_methods_1(self, mock_load_default_properties):
327331
.with_defined_metadata_list(self.payload["definedMetadataList"])
328332
.with_provenance_metadata(self.payload["provenanceMetadata"])
329333
.with_artifact(self.payload["artifact"])
330-
.with_backup_setting(self.payload['backupSetting'])
331-
.with_retention_setting(self.payload['retentionSetting'])
334+
.with_backup_setting(self.payload["backupSetting"])
335+
.with_retention_setting(self.payload["retentionSetting"])
332336
)
333337
assert self.prepare_dict(dsc_model.to_dict()["spec"]) == self.prepare_dict(
334338
self.payload
@@ -356,8 +360,12 @@ def test_with_methods_2(self):
356360
ModelProvenanceMetadata.from_dict(self.payload["provenanceMetadata"])
357361
)
358362
.with_artifact(self.payload["artifact"])
359-
.with_backup_setting(ModelBackupSetting.from_dict(self.payload['backupSetting']))
360-
.with_retention_setting(ModelRetentionSetting.from_dict(self.payload['retentionSetting']))
363+
.with_backup_setting(
364+
ModelBackupSetting.from_dict(self.payload["backupSetting"])
365+
)
366+
.with_retention_setting(
367+
ModelRetentionSetting.from_dict(self.payload["retentionSetting"])
368+
)
361369
)
362370
assert self.prepare_dict(dsc_model.to_dict()["spec"]) == self.prepare_dict(
363371
self.payload
@@ -605,7 +613,7 @@ def test__to_oci_dsc_model(self):
605613
True,
606614
],
607615
)
608-
@patch.object(OCIDataScienceModel, "is_model_by_reference")
616+
@patch.object(OCIDataScienceModel, "_is_model_by_reference")
609617
@patch.object(OCIDataScienceModel, "get_artifact_info")
610618
@patch.object(OCIDataScienceModel, "get_model_provenance")
611619
@patch.object(DataScienceModel, "_download_file_description_artifact")

tests/unitary/default_setup/model/test_model_artifact.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pickle
88
import sys
99

10-
import mock
10+
from unittest import mock
1111
import pytest
1212
import yaml
1313
from ads.common.model import ADSModel

0 commit comments

Comments
 (0)