Skip to content

Commit d35b917

Browse files
committed
Added class DataScienceModelCollection that extends from DataScienceModel
1 parent 848972e commit d35b917

File tree

1 file changed

+280
-0
lines changed

1 file changed

+280
-0
lines changed

ads/model/model_collection.py

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
import json
2+
import ads.common
3+
import oci
4+
import os
5+
import ads
6+
from ads.model.datascience_model import DataScienceModel
7+
from typing import List, Optional
8+
import logging
9+
10+
logger = logging.getLogger("ads.model_description")
11+
logger.setLevel(logging.INFO)
12+
13+
14+
class DataScienceModelCollection(DataScienceModel):
15+
16+
def _auth(self):
17+
"""
18+
Internal method that authenticates the model description instance by initializing OCI clients.
19+
20+
Parameters:
21+
- None
22+
23+
Returns:
24+
- None
25+
26+
Note:
27+
- This method retrieves authentication data using default signer from the `ads.common.auth` module.
28+
- The region information is extracted from the authentication data.
29+
"""
30+
authData = ads.common.auth.default_signer()
31+
signer = authData["signer"]
32+
self.region = authData["config"]["region"]
33+
34+
# data science client
35+
self.data_science_client = oci.data_science.DataScienceClient(
36+
{"region": self.region}, signer=signer
37+
)
38+
# oss client
39+
self.object_storage_client = oci.object_storage.ObjectStorageClient(
40+
{"region": self.region}, signer=signer
41+
)
42+
43+
def __init__(self, spec: ads.Dict = None, **kwargs) -> None:
44+
super().__init__(spec, **kwargs)
45+
46+
self.empty_json = {
47+
"version": "1.0",
48+
"type": "modelOSSReferenceDescription",
49+
"models": [],
50+
}
51+
self.region = ""
52+
self._auth()
53+
54+
self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, self.empty_json)
55+
56+
def with_ref_model_id(self, model_ocid: str):
57+
58+
# if model given then get that as the starting reference point
59+
logger.info("Getting model details from backend")
60+
try:
61+
get_model_artifact_content_response = (
62+
self.data_science_client.get_model_artifact_content(
63+
model_id=model_ocid,
64+
)
65+
)
66+
content = get_model_artifact_content_response.data.content
67+
self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, json.loads(content))
68+
except json.JSONDecodeError as e:
69+
logger.error(f"Error decoding JSON: {e}")
70+
raise e
71+
except Exception as e:
72+
logger.error(f"An unexpected error occurred: {e}")
73+
raise e
74+
return self
75+
76+
def add(self, namespace: str, bucket: str, prefix: Optional[str] =None, files: Optional[List[str]] =None):
77+
"""
78+
Adds information about objects in a specified bucket to the model description JSON.
79+
80+
Parameters:
81+
- namespace (str): The namespace of the object storage.
82+
- bucket (str): The name of the bucket containing the objects.
83+
- prefix (str, optional): The prefix used to filter objects within the bucket. Defaults to None.
84+
- files (list of str, optional): A list of file names to include in the model description.
85+
If provided, only objects with matching file names will be included. Defaults to None.
86+
87+
Returns:
88+
- None
89+
90+
Raises:
91+
- ValueError: If no files are found to add to the model description.
92+
93+
Note:
94+
- If `files` is not provided, it retrieves information about all objects in the bucket.
95+
If `files` is provided, it only retrieves information about objects with matching file names.
96+
- If no objects are found to add to the model description, a ValueError is raised.
97+
"""
98+
99+
# Remove if the model already exists
100+
self.remove(namespace, bucket, prefix)
101+
102+
def check_if_file_exists(fileName):
103+
isExists = False
104+
try:
105+
headResponse = self.object_storage_client.head_object(
106+
namespace, bucket, object_name=fileName
107+
)
108+
if headResponse.status == 200:
109+
isExists = True
110+
except Exception as e:
111+
if hasattr(e, "status") and e.status == 404:
112+
logger.error(f"File not found in bucket: {fileName}")
113+
else:
114+
logger.error(f"An error occured: {e}")
115+
return isExists
116+
117+
# Function to un-paginate the api call with while loop
118+
def list_obj_versions_unpaginated():
119+
objectStorageList = []
120+
has_next_page, opc_next_page = True, None
121+
while has_next_page:
122+
response = self.object_storage_client.list_object_versions(
123+
namespace_name=namespace,
124+
bucket_name=bucket,
125+
prefix=prefix,
126+
fields="name,size",
127+
page=opc_next_page,
128+
)
129+
objectStorageList.extend(response.data.items)
130+
has_next_page = response.has_next_page
131+
opc_next_page = response.next_page
132+
return objectStorageList
133+
134+
# Fetch object details and put it into the objects variable
135+
objectStorageList = []
136+
if files == None:
137+
objectStorageList = list_obj_versions_unpaginated()
138+
else:
139+
for fileName in files:
140+
if check_if_file_exists(fileName=fileName):
141+
objectStorageList.append(
142+
self.object_storage_client.list_object_versions(
143+
namespace_name=namespace,
144+
bucket_name=bucket,
145+
prefix=fileName,
146+
fields="name,size",
147+
).data.items[0]
148+
)
149+
150+
objects = [
151+
{"name": obj.name, "version": obj.version_id, "sizeInBytes": obj.size}
152+
for obj in objectStorageList
153+
if obj.size > 0
154+
]
155+
156+
if len(objects) == 0:
157+
error_message = (
158+
f"No files to add in the bucket: {bucket} with namespace: {namespace} "
159+
f"and prefix: {prefix}. File names: {files}"
160+
)
161+
logger.error(error_message)
162+
raise ValueError(error_message)
163+
164+
tmp_model_file_description = self.model_file_description
165+
tmp_model_file_description['models'].append({
166+
"namespace": namespace,
167+
"bucketName": bucket,
168+
"prefix": prefix,
169+
"objects": objects,
170+
})
171+
self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, tmp_model_file_description)
172+
173+
def remove(self, namespace: str, bucket: str, prefix: Optional[str]=None):
174+
"""
175+
Removes information about objects in a specified bucket from the model description JSON.
176+
177+
Parameters:
178+
- namespace (str): The namespace of the object storage.
179+
- bucket (str): The name of the bucket containing the objects.
180+
- prefix (str, optional): The prefix used to filter objects within the bucket. Defaults to None.
181+
182+
Returns:
183+
- None
184+
185+
Note:
186+
- This method removes information about objects in the specified bucket from the
187+
instance of the ModelDescription.
188+
- If a matching model (with the specified namespace, bucket name, and prefix) is found
189+
in the model description JSON, it is removed.
190+
- If no matching model is found, the method returns without making any changes.
191+
"""
192+
193+
def findModelIdx():
194+
for idx, model in enumerate(self.model_file_description["models"]):
195+
if (
196+
model["namespace"],
197+
model["bucketName"],
198+
(model["prefix"] if ("prefix" in model) else None),
199+
) == (namespace, bucket, prefix):
200+
return idx
201+
return -1
202+
203+
modelSearchIdx = findModelIdx()
204+
if modelSearchIdx == -1:
205+
return
206+
else:
207+
# model found case
208+
self.model_file_description["models"].pop(modelSearchIdx)
209+
210+
def create(self):
211+
"""
212+
Saves the model to the Model Catalog of Oracle Cloud Infrastructure (OCI) Data Science service.
213+
214+
Parameters:
215+
- project_ocid (str): The OCID (Oracle Cloud Identifier) of the OCI Data Science project.
216+
- compartment_ocid (str): The OCID of the compartment in which the model will be created.
217+
- display_name (str, optional): The display name for the created model. If not provided,
218+
a default display name indicating the creation timestamp is used. Defaults to None.
219+
220+
Returns:
221+
- str: The OCID of the created model.
222+
223+
Note:
224+
- The display name defaults to a string indicating the creation timestamp if not provided.
225+
"""
226+
tmp_file_path = self.build()
227+
self = self.with_artifact(uri=tmp_file_path)
228+
created_model = super().create()
229+
try:
230+
os.remove(tmp_file_path)
231+
except Exception as e:
232+
logger.error(f"Error occurred while cleaning file: {e}")
233+
raise e
234+
return created_model.id
235+
236+
def build(self) -> str:
237+
"""
238+
Builds the model description JSON and writes it to a file.
239+
240+
Parameters:
241+
- None
242+
243+
Returns:
244+
- str: The absolute file path where the model description JSON is stored.
245+
246+
Note:
247+
- This method serializes the current model description attribute to a JSON file named 'resultModelDescription.json' with an indentation of 2 spaces.
248+
"""
249+
logger.info("Building...")
250+
file_path = "resultModelDescription.json"
251+
try:
252+
with open(file_path, "w") as json_file:
253+
json.dump(self.model_file_description, json_file, indent=2)
254+
except IOError as e:
255+
logger.error(
256+
f"Error writing to file '{file_path}': {e}"
257+
) # Handle the exception accordingly, e.g., log the error, retry writing, etc.
258+
except Exception as e:
259+
logger.error(
260+
f"An unexpected error occurred: {e}"
261+
) # Handle other unexpected exceptions
262+
logger.info("Model Artifact stored successfully.")
263+
return os.path.abspath(file_path)
264+
265+
def show(self):
266+
"""
267+
Displays the current model description JSON in a human-readable format.
268+
269+
Parameters:
270+
- None
271+
272+
Returns:
273+
- str: The json representation of current model artifact
274+
275+
Note:
276+
- The JSON representation of the model description is formatted with an indentation
277+
of 4 spaces.
278+
"""
279+
logger.info(json.dumps(self.model_file_description, indent=4))
280+
return json.dumps(self.model_file_description, indent=4)

0 commit comments

Comments
 (0)