Skip to content

Commit 651a0c0

Browse files
committed
updated changes as per comments
1 parent 4ce3fcf commit 651a0c0

File tree

5 files changed

+124
-159
lines changed

5 files changed

+124
-159
lines changed

ads/model/datascience_model.py

Lines changed: 40 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

43
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -11,19 +10,22 @@
1110
import shutil
1211
import tempfile
1312
from copy import deepcopy
14-
from typing import Dict, List, Optional, Union, Tuple
13+
from typing import Dict, List, Optional, Tuple, Union
1514

1615
import pandas
1716
import yaml
1817
from jsonschema import ValidationError, validate
1918

19+
from ads.common import oci_client as oc
2020
from ads.common import utils
2121
from ads.common.extended_enum import ExtendedEnumMeta
2222
from ads.common.object_storage_details import ObjectStorageDetails
23+
from ads.config import (
24+
AQUA_SERVICE_MODELS_BUCKET as SERVICE_MODELS_BUCKET,
25+
)
2326
from ads.config import (
2427
COMPARTMENT_OCID,
2528
PROJECT_OCID,
26-
AQUA_SERVICE_MODELS_BUCKET as SERVICE_MODELS_BUCKET,
2729
)
2830
from ads.feature_engineering.schema import Schema
2931
from ads.jobs.builders.base import Builder
@@ -37,12 +39,12 @@
3739
ModelCustomMetadata,
3840
ModelCustomMetadataItem,
3941
ModelProvenanceMetadata,
40-
ModelTaxonomyMetadata, )
42+
ModelTaxonomyMetadata,
43+
)
4144
from ads.model.service.oci_datascience_model import (
4245
ModelProvenanceNotFoundError,
4346
OCIDataScienceModel,
4447
)
45-
from ads.common import oci_client as oc
4648

4749
logger = logging.getLogger(__name__)
4850

@@ -84,14 +86,6 @@ class CustomerNotificationType(str, metaclass=ExtendedEnumMeta):
8486
ON_FAILURE = "ON_FAILURE"
8587
ON_SUCCESS = "ON_SUCCESS"
8688

87-
@classmethod
88-
def is_valid(cls, value):
89-
return value in (cls.NONE, cls.ALL, cls.ON_FAILURE, cls.ON_SUCCESS)
90-
91-
@property
92-
def value(self):
93-
return str(self)
94-
9589

9690
class SettingStatus(str, metaclass=ExtendedEnumMeta):
9791
"""Enum to represent the status of retention settings."""
@@ -100,11 +94,6 @@ class SettingStatus(str, metaclass=ExtendedEnumMeta):
10094
SUCCEEDED = "SUCCEEDED"
10195
FAILED = "FAILED"
10296

103-
@classmethod
104-
def is_valid(cls, state: str) -> bool:
105-
"""Validates the given state against allowed SettingStatus values."""
106-
return state in (cls.PENDING, cls.SUCCEEDED, cls.FAILED)
107-
10897

10998
class ModelBackupSetting:
11099
"""
@@ -167,10 +156,7 @@ def to_json(self) -> str:
167156
@classmethod
168157
def from_json(cls, json_str) -> "ModelBackupSetting":
169158
"""Constructs backup settings from a JSON string or dictionary."""
170-
if isinstance(json_str, str):
171-
data = json.loads(json_str)
172-
else:
173-
data = json_str
159+
data = json.loads(json_str) if isinstance(json_str, str) else json_str
174160

175161
return cls.from_dict(data)
176162

@@ -180,14 +166,12 @@ def to_yaml(self) -> str:
180166

181167
def validate(self) -> bool:
182168
"""Validates the backup settings details. Returns True if valid, False otherwise."""
183-
if not isinstance(self.is_backup_enabled, bool):
184-
return False
185-
if self.backup_region and not isinstance(self.backup_region, str):
186-
return False
187-
if not isinstance(self.customer_notification_type, str) \
188-
or not CustomerNotificationType.is_valid(self.customer_notification_type):
189-
return False
190-
return True
169+
return all([
170+
isinstance(self.is_backup_enabled, bool),
171+
not self.backup_region or isinstance(self.backup_region, str),
172+
isinstance(self.customer_notification_type, str) and self.customer_notification_type in
173+
CustomerNotificationType.values()
174+
])
191175

192176
def __repr__(self):
193177
return self.to_yaml()
@@ -252,10 +236,7 @@ def to_json(self) -> str:
252236
@classmethod
253237
def from_json(cls, json_str) -> "ModelRetentionSetting":
254238
"""Constructs retention settings from a JSON string."""
255-
if isinstance(json_str, str):
256-
data = json.loads(json_str)
257-
else:
258-
data = json_str
239+
data = json.loads(json_str) if isinstance(json_str, str) else json_str
259240
return cls.from_dict(data)
260241

261242
def to_yaml(self) -> str:
@@ -264,18 +245,13 @@ def to_yaml(self) -> str:
264245

265246
def validate(self) -> bool:
266247
"""Validates the retention settings details. Returns True if valid, False otherwise."""
267-
if self.archive_after_days is not None and (
268-
not isinstance(self.archive_after_days, int) or self.archive_after_days < 0
269-
):
270-
return False
271-
if self.delete_after_days is not None and (
272-
not isinstance(self.delete_after_days, int) or self.delete_after_days < 0
273-
):
274-
return False
275-
if not isinstance(self.customer_notification_type, str) or not \
276-
CustomerNotificationType.is_valid(self.customer_notification_type):
277-
return False
278-
return True
248+
return all([
249+
self.archive_after_days is None or (
250+
isinstance(self.archive_after_days, int) and self.archive_after_days >= 0),
251+
self.delete_after_days is None or (isinstance(self.delete_after_days, int) and self.delete_after_days >= 0),
252+
isinstance(self.customer_notification_type, str) and self.customer_notification_type in
253+
CustomerNotificationType.values()
254+
])
279255

280256
def __repr__(self):
281257
return self.to_yaml()
@@ -358,8 +334,8 @@ def validate(self) -> bool:
358334
"""Validates the retention operation details."""
359335
return all(
360336
[
361-
self.archive_state is None or SettingStatus.is_valid(self.archive_state),
362-
self.delete_state is None or SettingStatus.is_valid(self.delete_state),
337+
self.archive_state is None or self.archive_state in SettingStatus.values(),
338+
self.delete_state is None or self.delete_state in SettingStatus.values(),
363339
self.time_archival_scheduled is None
364340
or isinstance(self.time_archival_scheduled, int),
365341
self.time_deletion_scheduled is None
@@ -395,18 +371,18 @@ def __init__(
395371
self,
396372
backup_state: Optional[SettingStatus] = None,
397373
backup_state_details: Optional[str] = None,
398-
time_last_backed_up: Optional[int] = None,
374+
time_last_backup: Optional[int] = None,
399375
):
400376
self.backup_state = backup_state
401377
self.backup_state_details = backup_state_details
402-
self.time_last_backed_up = time_last_backed_up
378+
self.time_last_backup = time_last_backup
403379

404380
def to_dict(self) -> Dict:
405381
"""Serializes the backup operation details into a dictionary."""
406382
return {
407383
"backup_state": self.backup_state or None,
408384
"backup_state_details": self.backup_state_details,
409-
"time_last_backed_up": self.time_last_backed_up,
385+
"time_last_backup": self.time_last_backup,
410386
}
411387

412388
@classmethod
@@ -415,7 +391,7 @@ def from_dict(cls, data: Dict) -> "ModelBackupOperationDetails":
415391
return cls(
416392
backup_state=SettingStatus(data.get("backup_state")) or None,
417393
backup_state_details=data.get("backup_state_details"),
418-
time_last_backed_up=data.get("time_last_backed_up"),
394+
time_last_backup=data.get("time_last_backup"),
419395
)
420396

421397
def to_json(self) -> str:
@@ -434,13 +410,10 @@ def to_yaml(self) -> str:
434410

435411
def validate(self) -> bool:
436412
"""Validates the backup operation details."""
437-
if self.backup_state is not None and not SettingStatus.is_valid(self.backup_state):
438-
return False
439-
if self.time_last_backed_up is not None and not isinstance(
440-
self.time_last_backed_up, int
441-
):
442-
return False
443-
return True
413+
return not (
414+
(self.backup_state is not None and not self.backup_state in SettingStatus.values()) or
415+
(self.time_last_backup is not None and not isinstance(self.time_last_backup, int))
416+
)
444417

445418
def __repr__(self):
446419
return self.to_yaml()
@@ -1068,7 +1041,7 @@ def with_model_file_description(
10681041
elif json_string:
10691042
json_data = json.loads(json_string)
10701043
elif json_uri:
1071-
with open(json_uri, "r") as json_file:
1044+
with open(json_uri) as json_file:
10721045
json_data = json.load(json_file)
10731046
else:
10741047
raise ValueError("Must provide either a valid json string or URI location.")
@@ -1423,17 +1396,11 @@ def restore_model(
14231396
return
14241397

14251398
# Optional: Validate restore_model_for_hours_specified
1426-
if restore_model_for_hours_specified is not None:
1427-
if (
1428-
not isinstance(restore_model_for_hours_specified, int)
1429-
or restore_model_for_hours_specified <= 0
1430-
):
1431-
raise ValueError(
1432-
"restore_model_for_hours_specified must be a positive integer."
1433-
)
1399+
if restore_model_for_hours_specified is not None and (
1400+
not isinstance(restore_model_for_hours_specified, int) or restore_model_for_hours_specified <= 0):
1401+
raise ValueError("restore_model_for_hours_specified must be a positive integer.")
14341402

14351403
self.dsc_model.restore_archived_model_artifact(
1436-
model_id=self.id,
14371404
restore_model_for_hours_specified=restore_model_for_hours_specified,
14381405
)
14391406

@@ -1692,8 +1659,8 @@ def _init_complex_attributes(self):
16921659
self.with_provenance_metadata(self.provenance_metadata)
16931660
self.with_input_schema(self.input_schema)
16941661
self.with_output_schema(self.output_schema)
1695-
self.with_backup_setting(self.backup_setting)
1696-
self.with_retention_setting(self.retention_setting)
1662+
# self.with_backup_setting(self.backup_setting)
1663+
# self.with_retention_setting(self.retention_setting)
16971664

16981665
def _to_oci_dsc_model(self, **kwargs):
16991666
"""Creates an `OCIDataScienceModel` instance from the `DataScienceModel`.
@@ -1753,6 +1720,8 @@ def _update_from_oci_dsc_model(
17531720
self.CONST_DEFINED_METADATA: ModelTaxonomyMetadata._from_oci_metadata,
17541721
self.CONST_BACKUP_SETTING: ModelBackupSetting.to_dict,
17551722
self.CONST_RETENTION_SETTING: ModelRetentionSetting.to_dict,
1723+
self.CONST_BACKUP_OPERATION_DETAILS: ModelBackupOperationDetails.to_dict,
1724+
self.CONST_RETENTION_OPERATION_DETAILS: ModelRetentionOperationDetails.to_dict
17561725
}
17571726

17581727
# Update the main properties

ads/model/generic_model.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,9 +1422,9 @@ def from_model_artifact(
14221422
)
14231423
model.update_summary_status(
14241424
detail=PREPARE_STATUS_POPULATE_METADATA_DETAIL,
1425-
status=(
1426-
ModelState.AVAILABLE.value if reload else ModelState.NOTAPPLICABLE.value
1427-
),
1425+
status=ModelState.AVAILABLE.value
1426+
if reload
1427+
else ModelState.NOTAPPLICABLE.value,
14281428
)
14291429

14301430
return model
@@ -1706,11 +1706,9 @@ def from_model_catalog(
17061706
)
17071707
result_model.update_summary_status(
17081708
detail=SAVE_STATUS_INTROSPECT_TEST_DETAIL,
1709-
status=(
1710-
ModelState.AVAILABLE.value
1711-
if not result_model.ignore_conda_error
1712-
else ModelState.NOTAVAILABLE.value
1713-
),
1709+
status=ModelState.AVAILABLE.value
1710+
if not result_model.ignore_conda_error
1711+
else ModelState.NOTAVAILABLE.value,
17141712
)
17151713
return result_model
17161714

0 commit comments

Comments
 (0)