Skip to content

Commit d4ef023

Browse files
committed
Added add_artifact and remove_artifact method in main DataScienceModel class itself.
1 parent 3d4d950 commit d4ef023

File tree

1 file changed

+158
-0
lines changed

1 file changed

+158
-0
lines changed

ads/model/datascience_model.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
ModelProvenanceNotFoundError,
4242
OCIDataScienceModel,
4343
)
44+
from ads.common import oci_client as oc
45+
from ads.common.auth import default_signer
4446

4547
logger = logging.getLogger(__name__)
4648

@@ -1466,3 +1468,159 @@ def _download_file_description_artifact(self) -> Tuple[Union[str, List[str]], in
14661468
bucket_uri.append(uri)
14671469

14681470
return bucket_uri[0] if len(bucket_uri) == 1 else bucket_uri, artifact_size
1471+
1472+
def add_artifact(
1473+
self,
1474+
namespace: str,
1475+
bucket: str,
1476+
prefix: Optional[str] = None,
1477+
files: Optional[List[str]] = None,
1478+
):
1479+
"""
1480+
Adds information about objects in a specified bucket to the model description JSON.
1481+
1482+
Parameters:
1483+
- namespace (str): The namespace of the object storage.
1484+
- bucket (str): The name of the bucket containing the objects.
1485+
- prefix (str, optional): The prefix used to filter objects within the bucket. Defaults to None.
1486+
- files (list of str, optional): A list of file names to include in the model description.
1487+
If provided, only objects with matching file names will be included. Defaults to None.
1488+
1489+
Returns:
1490+
- None
1491+
1492+
Raises:
1493+
- ValueError: If no files are found to add to the model description.
1494+
1495+
Note:
1496+
- If `files` is not provided, it retrieves information about all objects in the bucket.
1497+
If `files` is provided, it only retrieves information about objects with matching file names.
1498+
- If no objects are found to add to the model description, a ValueError is raised.
1499+
"""
1500+
if self.model_file_description == None:
1501+
self.empty_json = {
1502+
"version": "1.0",
1503+
"type": "modelOSSReferenceDescription",
1504+
"models": [],
1505+
}
1506+
self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, self.empty_json)
1507+
1508+
# Get object storage client
1509+
authData = default_signer()
1510+
self.object_storage_client = oc.OCIClientFactory(**authData).object_storage
1511+
1512+
# Remove if the model already exists
1513+
self.remove_artifact(namespace, bucket, prefix)
1514+
1515+
def check_if_file_exists(fileName):
1516+
isExists = False
1517+
try:
1518+
headResponse = self.object_storage_client.head_object(
1519+
namespace, bucket, object_name=fileName
1520+
)
1521+
if headResponse.status == 200:
1522+
isExists = True
1523+
except Exception as e:
1524+
if hasattr(e, "status") and e.status == 404:
1525+
logger.error(f"File not found in bucket: {fileName}")
1526+
else:
1527+
logger.error(f"An error occured: {e}")
1528+
return isExists
1529+
1530+
# Function to un-paginate the api call with while loop
1531+
def list_obj_versions_unpaginated():
1532+
objectStorageList = []
1533+
has_next_page, opc_next_page = True, None
1534+
while has_next_page:
1535+
response = self.object_storage_client.list_object_versions(
1536+
namespace_name=namespace,
1537+
bucket_name=bucket,
1538+
prefix=prefix,
1539+
fields="name,size",
1540+
page=opc_next_page,
1541+
)
1542+
objectStorageList.extend(response.data.items)
1543+
has_next_page = response.has_next_page
1544+
opc_next_page = response.next_page
1545+
return objectStorageList
1546+
1547+
# Fetch object details and put it into the objects variable
1548+
objectStorageList = []
1549+
if files == None:
1550+
objectStorageList = list_obj_versions_unpaginated()
1551+
else:
1552+
for fileName in files:
1553+
if check_if_file_exists(fileName=fileName):
1554+
objectStorageList.append(
1555+
self.object_storage_client.list_object_versions(
1556+
namespace_name=namespace,
1557+
bucket_name=bucket,
1558+
prefix=fileName,
1559+
fields="name,size",
1560+
).data.items[0]
1561+
)
1562+
1563+
objects = [
1564+
{"name": obj.name, "version": obj.version_id, "sizeInBytes": obj.size}
1565+
for obj in objectStorageList
1566+
if obj.size > 0
1567+
]
1568+
1569+
if len(objects) == 0:
1570+
error_message = (
1571+
f"No files to add in the bucket: {bucket} with namespace: {namespace} "
1572+
f"and prefix: {prefix}. File names: {files}"
1573+
)
1574+
logger.error(error_message)
1575+
raise ValueError(error_message)
1576+
1577+
tmp_model_file_description = self.model_file_description
1578+
tmp_model_file_description["models"].append(
1579+
{
1580+
"namespace": namespace,
1581+
"bucketName": bucket,
1582+
"prefix": prefix,
1583+
"objects": objects,
1584+
}
1585+
)
1586+
self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, tmp_model_file_description)
1587+
1588+
def remove_artifact(self, namespace: str, bucket: str, prefix: Optional[str] = None):
1589+
"""
1590+
Removes information about objects in a specified bucket from the model description JSON.
1591+
1592+
Parameters:
1593+
- namespace (str): The namespace of the object storage.
1594+
- bucket (str): The name of the bucket containing the objects.
1595+
- prefix (str, optional): The prefix used to filter objects within the bucket. Defaults to None.
1596+
1597+
Returns:
1598+
- None
1599+
1600+
Note:
1601+
- This method removes information about objects in the specified bucket from the
1602+
instance of the ModelDescription.
1603+
- If a matching model (with the specified namespace, bucket name, and prefix) is found
1604+
in the model description JSON, it is removed.
1605+
- If no matching model is found, the method returns without making any changes.
1606+
"""
1607+
1608+
def findModelIdx():
1609+
for idx, model in enumerate(self.model_file_description["models"]):
1610+
if (
1611+
model["namespace"],
1612+
model["bucketName"],
1613+
(model["prefix"] if ("prefix" in model) else None),
1614+
) == (namespace, bucket, prefix):
1615+
return idx
1616+
return -1
1617+
1618+
if self.model_file_description == None:
1619+
return
1620+
1621+
modelSearchIdx = findModelIdx()
1622+
if modelSearchIdx == -1:
1623+
return
1624+
else:
1625+
# model found case
1626+
self.model_file_description["models"].pop(modelSearchIdx)

0 commit comments

Comments
 (0)