Skip to content

Commit d0309f4

Browse files
committed
Added uri based approach
1 parent c40ea8b commit d0309f4

File tree

1 file changed

+31
-6
lines changed

1 file changed

+31
-6
lines changed

ads/model/datascience_model.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
import cgi
8+
import re
89
import json
910
import logging
1011
import os
@@ -1471,9 +1472,7 @@ def _download_file_description_artifact(self) -> Tuple[Union[str, List[str]], in
14711472

14721473
def add_artifact(
14731474
self,
1474-
namespace: str,
1475-
bucket: str,
1476-
prefix: Optional[str] = None,
1475+
uri: str,
14771476
files: Optional[List[str]] = None,
14781477
):
14791478
"""
@@ -1497,6 +1496,9 @@ def add_artifact(
14971496
If `files` is provided, it only retrieves information about objects with matching file names.
14981497
- If no objects are found to add to the model description, a ValueError is raised.
14991498
"""
1499+
1500+
bucket, namespace, prefix = self._extract_oci_uri_components(uri)
1501+
15001502
if self.model_file_description == None:
15011503
self.empty_json = {
15021504
"version": "1.0",
@@ -1510,7 +1512,7 @@ def add_artifact(
15101512
self.object_storage_client = oc.OCIClientFactory(**authData).object_storage
15111513

15121514
# Remove if the model already exists
1513-
self.remove_artifact(namespace, bucket, prefix)
1515+
self.remove_artifact(uri=uri)
15141516

15151517
def check_if_file_exists(fileName):
15161518
isExists = False
@@ -1585,7 +1587,7 @@ def list_obj_versions_unpaginated():
15851587
)
15861588
self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, tmp_model_file_description)
15871589

1588-
def remove_artifact(self, namespace: str, bucket: str, prefix: Optional[str] = None):
1590+
def remove_artifact(self, uri: str):
15891591
"""
15901592
Removes information about objects in a specified bucket from the model description JSON.
15911593
@@ -1605,6 +1607,8 @@ def remove_artifact(self, namespace: str, bucket: str, prefix: Optional[str] = N
16051607
- If no matching model is found, the method returns without making any changes.
16061608
"""
16071609

1610+
bucket, namespace, prefix = self._extract_oci_uri_components(uri)
1611+
16081612
def findModelIdx():
16091613
for idx, model in enumerate(self.model_file_description["models"]):
16101614
if (
@@ -1623,4 +1627,25 @@ def findModelIdx():
16231627
return
16241628
else:
16251629
# model found case
1626-
self.model_file_description["models"].pop(modelSearchIdx)
1630+
self.model_file_description["models"].pop(modelSearchIdx)
1631+
1632+
def _extract_oci_uri_components(self, uri: str):
1633+
# Define the regular expression pattern to match the URI format
1634+
pattern = r"oci://(?P<bucket_name>[^@]+)@(?P<namespace>[^/]+)(?:/(?P<prefix>.*))?"
1635+
1636+
# Use re.match to apply the pattern to the URI
1637+
match = re.match(pattern, uri)
1638+
1639+
if match:
1640+
# Extract named groups using the groupdict() method
1641+
components = match.groupdict()
1642+
prefix = components.get('prefix', '')
1643+
# Treat a single trailing slash as no prefix
1644+
if prefix == "":
1645+
return components['bucket_name'], components['namespace'], ''
1646+
elif prefix == "/":
1647+
return components['bucket_name'], components['namespace'], ''
1648+
else:
1649+
return components['bucket_name'], components['namespace'], prefix
1650+
else:
1651+
raise ValueError("The URI format is incorrect")

0 commit comments

Comments
 (0)