Skip to content

Commit b9d923d

Browse files
committed
updated with formatter changes
1 parent cf65397 commit b9d923d

File tree

1 file changed

+59
-62
lines changed

1 file changed

+59
-62
lines changed

ads/model/datascience_model.py

Lines changed: 59 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,10 @@
2020
from ads.common import utils
2121
from ads.common.extended_enum import ExtendedEnumMeta
2222
from ads.common.object_storage_details import ObjectStorageDetails
23-
from ads.config import (
24-
AQUA_SERVICE_MODELS_BUCKET as SERVICE_MODELS_BUCKET,
25-
)
2623
from ads.config import (
2724
COMPARTMENT_OCID,
2825
PROJECT_OCID,
26+
AQUA_SERVICE_MODELS_BUCKET as SERVICE_MODELS_BUCKET,
2927
)
3028
from ads.feature_engineering.schema import Schema
3129
from ads.jobs.builders.base import Builder
@@ -48,6 +46,7 @@
4846

4947
logger = logging.getLogger(__name__)
5048

49+
5150
_MAX_ARTIFACT_SIZE_IN_BYTES = 2147483648 # 2GB
5251
MODEL_BY_REFERENCE_VERSION = "1.0"
5352
MODEL_BY_REFERENCE_JSON_FILE_NAME = "model_description.json"
@@ -65,8 +64,8 @@ def __init__(self, max_artifact_size: str):
6564

6665
class BucketNotVersionedError(Exception): # pragma: no cover
6766
def __init__(
68-
self,
69-
msg="Model artifact bucket is not versioned. Enable versioning on the bucket to proceed with model creation by reference.",
67+
self,
68+
msg="Model artifact bucket is not versioned. Enable versioning on the bucket to proceed with model creation by reference.",
7069
):
7170
super().__init__(msg)
7271

@@ -527,6 +526,7 @@ class DataScienceModel(Builder):
527526
Sets path details for models created by reference. Input can be either a dict, string or json file and
528527
the schema is dictated by model_file_description_schema.json
529528
529+
530530
Examples
531531
--------
532532
>>> ds_model = (DataScienceModel()
@@ -796,7 +796,7 @@ def defined_tags(self) -> Dict[str, Dict[str, object]]:
796796
return self.get_spec(self.CONST_DEFINED_TAG)
797797

798798
def with_defined_tags(
799-
self, **kwargs: Dict[str, Dict[str, object]]
799+
self, **kwargs: Dict[str, Dict[str, object]]
800800
) -> "DataScienceModel":
801801
"""Sets defined tags.
802802
@@ -877,7 +877,7 @@ def defined_metadata_list(self) -> ModelTaxonomyMetadata:
877877
return self.get_spec(self.CONST_DEFINED_METADATA)
878878

879879
def with_defined_metadata_list(
880-
self, metadata: Union[ModelTaxonomyMetadata, Dict]
880+
self, metadata: Union[ModelTaxonomyMetadata, Dict]
881881
) -> "DataScienceModel":
882882
"""Sets model taxonomy (defined) metadata.
883883
@@ -901,7 +901,7 @@ def custom_metadata_list(self) -> ModelCustomMetadata:
901901
return self.get_spec(self.CONST_CUSTOM_METADATA)
902902

903903
def with_custom_metadata_list(
904-
self, metadata: Union[ModelCustomMetadata, Dict]
904+
self, metadata: Union[ModelCustomMetadata, Dict]
905905
) -> "DataScienceModel":
906906
"""Sets model custom metadata.
907907
@@ -925,7 +925,7 @@ def provenance_metadata(self) -> ModelProvenanceMetadata:
925925
return self.get_spec(self.CONST_PROVENANCE_METADATA)
926926

927927
def with_provenance_metadata(
928-
self, metadata: Union[ModelProvenanceMetadata, Dict]
928+
self, metadata: Union[ModelProvenanceMetadata, Dict]
929929
) -> "DataScienceModel":
930930
"""Sets model provenance metadata.
931931
@@ -1018,7 +1018,7 @@ def model_file_description(self) -> dict:
10181018
return self.get_spec(self.CONST_MODEL_FILE_DESCRIPTION)
10191019

10201020
def with_model_file_description(
1021-
self, json_dict: dict = None, json_string: str = None, json_uri: str = None
1021+
self, json_dict: dict = None, json_string: str = None, json_uri: str = None
10221022
):
10231023
"""Sets the json file description for model passed by reference
10241024
Parameters
@@ -1041,7 +1041,7 @@ def with_model_file_description(
10411041
elif json_string:
10421042
json_data = json.loads(json_string)
10431043
elif json_uri:
1044-
with open(json_uri) as json_file:
1044+
with open(json_uri, "r") as json_file:
10451045
json_data = json.load(json_file)
10461046
else:
10471047
raise ValueError("Must provide either a valid json string or URI location.")
@@ -1256,15 +1256,15 @@ def create(self, **kwargs) -> "DataScienceModel":
12561256
return self
12571257

12581258
def upload_artifact(
1259-
self,
1260-
bucket_uri: Optional[str] = None,
1261-
auth: Optional[Dict] = None,
1262-
region: Optional[str] = None,
1263-
overwrite_existing_artifact: Optional[bool] = True,
1264-
remove_existing_artifact: Optional[bool] = True,
1265-
timeout: Optional[int] = None,
1266-
parallel_process_count: int = utils.DEFAULT_PARALLEL_PROCESS_COUNT,
1267-
model_by_reference: Optional[bool] = False,
1259+
self,
1260+
bucket_uri: Optional[str] = None,
1261+
auth: Optional[Dict] = None,
1262+
region: Optional[str] = None,
1263+
overwrite_existing_artifact: Optional[bool] = True,
1264+
remove_existing_artifact: Optional[bool] = True,
1265+
timeout: Optional[int] = None,
1266+
parallel_process_count: int = utils.DEFAULT_PARALLEL_PROCESS_COUNT,
1267+
model_by_reference: Optional[bool] = False,
12681268
) -> None:
12691269
"""Uploads model artifacts to the model catalog.
12701270
@@ -1334,7 +1334,7 @@ def upload_artifact(
13341334
bucket_uri = self.artifact
13351335

13361336
if not model_by_reference and (
1337-
bucket_uri or utils.folder_size(self.artifact) > _MAX_ARTIFACT_SIZE_IN_BYTES
1337+
bucket_uri or utils.folder_size(self.artifact) > _MAX_ARTIFACT_SIZE_IN_BYTES
13381338
):
13391339
if not bucket_uri:
13401340
raise ModelArtifactSizeError(
@@ -1405,15 +1405,15 @@ def restore_model(
14051405
)
14061406

14071407
def download_artifact(
1408-
self,
1409-
target_dir: str,
1410-
auth: Optional[Dict] = None,
1411-
force_overwrite: Optional[bool] = False,
1412-
bucket_uri: Optional[str] = None,
1413-
region: Optional[str] = None,
1414-
overwrite_existing_artifact: Optional[bool] = True,
1415-
remove_existing_artifact: Optional[bool] = True,
1416-
timeout: Optional[int] = None,
1408+
self,
1409+
target_dir: str,
1410+
auth: Optional[Dict] = None,
1411+
force_overwrite: Optional[bool] = False,
1412+
bucket_uri: Optional[str] = None,
1413+
region: Optional[str] = None,
1414+
overwrite_existing_artifact: Optional[bool] = True,
1415+
remove_existing_artifact: Optional[bool] = True,
1416+
timeout: Optional[int] = None,
14171417
):
14181418
"""Downloads model artifacts from the model catalog.
14191419
@@ -1488,9 +1488,9 @@ def download_artifact(
14881488
)
14891489

14901490
if (
1491-
artifact_size > _MAX_ARTIFACT_SIZE_IN_BYTES
1492-
or bucket_uri
1493-
or model_by_reference
1491+
artifact_size > _MAX_ARTIFACT_SIZE_IN_BYTES
1492+
or bucket_uri
1493+
or model_by_reference
14941494
):
14951495
artifact_downloader = LargeArtifactDownloader(
14961496
dsc_model=self.dsc_model,
@@ -1536,22 +1536,21 @@ def update(self, **kwargs) -> "DataScienceModel":
15361536
self.dsc_model = self._to_oci_dsc_model(**kwargs).update()
15371537

15381538
logger.debug(f"Updating a model provenance metadata {self.provenance_metadata}")
1539-
if self.provenance_metadata:
1540-
try:
1541-
self.dsc_model.get_model_provenance()
1542-
self.dsc_model.update_model_provenance(
1543-
self.provenance_metadata._to_oci_metadata()
1544-
)
1545-
except ModelProvenanceNotFoundError:
1546-
self.dsc_model.create_model_provenance(
1547-
self.provenance_metadata._to_oci_metadata()
1548-
)
1539+
try:
1540+
self.dsc_model.get_model_provenance()
1541+
self.dsc_model.update_model_provenance(
1542+
self.provenance_metadata._to_oci_metadata()
1543+
)
1544+
except ModelProvenanceNotFoundError:
1545+
self.dsc_model.create_model_provenance(
1546+
self.provenance_metadata._to_oci_metadata()
1547+
)
15491548

15501549
return self.sync()
15511550

15521551
def delete(
1553-
self,
1554-
delete_associated_model_deployment: Optional[bool] = False,
1552+
self,
1553+
delete_associated_model_deployment: Optional[bool] = False,
15551554
) -> "DataScienceModel":
15561555
"""Removes model from the model catalog.
15571556
@@ -1570,7 +1569,7 @@ def delete(
15701569

15711570
@classmethod
15721571
def list(
1573-
cls, compartment_id: str = None, project_id: str = None, **kwargs
1572+
cls, compartment_id: str = None, project_id: str = None, **kwargs
15741573
) -> List["DataScienceModel"]:
15751574
"""Lists datascience models in a given compartment.
15761575
@@ -1597,7 +1596,7 @@ def list(
15971596

15981597
@classmethod
15991598
def list_df(
1600-
cls, compartment_id: str = None, project_id: str = None, **kwargs
1599+
cls, compartment_id: str = None, project_id: str = None, **kwargs
16011600
) -> "pandas.DataFrame":
16021601
"""Lists datascience models in a given compartment.
16031602
@@ -1617,7 +1616,7 @@ def list_df(
16171616
"""
16181617
records = []
16191618
for model in OCIDataScienceModel.list_resource(
1620-
compartment_id, project_id=project_id, **kwargs
1619+
compartment_id, project_id=project_id, **kwargs
16211620
):
16221621
records.append(
16231622
{
@@ -1660,8 +1659,6 @@ def _init_complex_attributes(self):
16601659
self.with_provenance_metadata(self.provenance_metadata)
16611660
self.with_input_schema(self.input_schema)
16621661
self.with_output_schema(self.output_schema)
1663-
# self.with_backup_setting(self.backup_setting)
1664-
# self.with_retention_setting(self.retention_setting)
16651662

16661663
def _to_oci_dsc_model(self, **kwargs):
16671664
"""Creates an `OCIDataScienceModel` instance from the `DataScienceModel`.
@@ -1700,7 +1697,7 @@ def _to_oci_dsc_model(self, **kwargs):
17001697
return OCIDataScienceModel(**dsc_spec)
17011698

17021699
def _update_from_oci_dsc_model(
1703-
self, dsc_model: OCIDataScienceModel
1700+
self, dsc_model: OCIDataScienceModel
17041701
) -> "DataScienceModel":
17051702
"""Update the properties from an OCIDataScienceModel object.
17061703
@@ -1973,12 +1970,12 @@ def _download_file_description_artifact(self) -> Tuple[Union[str, List[str]], in
19731970
return bucket_uri[0] if len(bucket_uri) == 1 else bucket_uri, artifact_size
19741971

19751972
def add_artifact(
1976-
self,
1977-
uri: Optional[str] = None,
1978-
namespace: Optional[str] = None,
1979-
bucket: Optional[str] = None,
1980-
prefix: Optional[str] = None,
1981-
files: Optional[List[str]] = None,
1973+
self,
1974+
uri: Optional[str] = None,
1975+
namespace: Optional[str] = None,
1976+
bucket: Optional[str] = None,
1977+
prefix: Optional[str] = None,
1978+
files: Optional[List[str]] = None,
19821979
):
19831980
"""
19841981
Adds information about objects in a specified bucket to the model description JSON.
@@ -2127,11 +2124,11 @@ def list_obj_versions_unpaginated():
21272124
self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, tmp_model_file_description)
21282125

21292126
def remove_artifact(
2130-
self,
2131-
uri: Optional[str] = None,
2132-
namespace: Optional[str] = None,
2133-
bucket: Optional[str] = None,
2134-
prefix: Optional[str] = None,
2127+
self,
2128+
uri: Optional[str] = None,
2129+
namespace: Optional[str] = None,
2130+
bucket: Optional[str] = None,
2131+
prefix: Optional[str] = None,
21352132
):
21362133
"""
21372134
Removes information about objects in a specified bucket or using a specified URI from the model description JSON.

0 commit comments

Comments
 (0)