1
+ import json
2
+ import ads .common
3
+ import oci
4
+ import os
5
+ import ads
6
+ from ads .model .datascience_model import DataScienceModel
7
+ from typing import List , Optional
8
+ import logging
9
+
10
# Module-level logger. The name is fixed (not ``__name__``) so that records
# from this module are always emitted under "ads.model_description".
logger = logging.getLogger("ads.model_description")
# INFO is required for the progress messages logged by this module
# (e.g. "Building...") to be emitted by this logger.
logger.setLevel(logging.INFO)
12
+
13
+
14
class DataScienceModelCollection(DataScienceModel):
    """
    A ``DataScienceModel`` whose artifact is a model file description JSON.

    The description (``type: modelOSSReferenceDescription``) holds a list of
    ``models`` entries; each entry records a namespace, bucket, prefix and the
    name / version / size of the Object Storage objects it refers to.

    Typical flow: ``add(...)`` one or more buckets, optionally ``show()`` the
    result, then ``create()`` to write the JSON to disk and save the model.
    """

    def _auth(self):
        """
        Internal method that initializes the OCI clients used by this instance.

        Parameters:
        - None

        Returns:
        - None

        Note:
        - Authentication data is retrieved with `ads.common.auth.default_signer()`.
        - The region is read from the ``config`` section of the authentication data
          and stored on ``self.region``.
        - Sets ``self.data_science_client`` and ``self.object_storage_client``.
        """
        auth_data = ads.common.auth.default_signer()
        signer = auth_data["signer"]
        self.region = auth_data["config"]["region"]

        # Data Science client (used to fetch existing model artifact content).
        self.data_science_client = oci.data_science.DataScienceClient(
            {"region": self.region}, signer=signer
        )
        # Object Storage client (used to inspect buckets and objects).
        self.object_storage_client = oci.object_storage.ObjectStorageClient(
            {"region": self.region}, signer=signer
        )

    def __init__(self, spec: Optional[dict] = None, **kwargs) -> None:
        """
        Initializes the instance, authenticates, and seeds an empty description.

        Parameters:
        - spec (dict, optional): Specification forwarded to the parent class.
        - **kwargs: Additional keyword arguments forwarded to the parent class.

        Note:
        - The model file description is always reset to an empty description
          here, after the parent initializer has run.
          NOTE(review): this overwrites any description supplied via ``spec`` /
          ``kwargs`` - confirm that is intended.
        """
        super().__init__(spec, **kwargs)

        # Template of a description that references no models.
        self.empty_json = {
            "version": "1.0",
            "type": "modelOSSReferenceDescription",
            "models": [],
        }
        self.region = ""
        self._auth()

        # Store a fresh copy (with its own "models" list) so that later
        # in-place edits of the description do not mutate the template.
        self.set_spec(
            self.CONST_MODEL_FILE_DESCRIPTION,
            {**self.empty_json, "models": []},
        )

    def with_ref_model_id(self, model_ocid: str) -> "DataScienceModelCollection":
        """
        Uses the artifact of an existing model as the starting description.

        Parameters:
        - model_ocid (str): The OCID of the model whose artifact content is loaded.

        Returns:
        - DataScienceModelCollection: This instance, to allow call chaining.

        Raises:
        - json.JSONDecodeError: If the artifact content is not valid JSON.
        - Exception: Any error raised while fetching the artifact content is
          logged and re-raised unchanged.
        """
        # If a model is given, its artifact becomes the starting reference point.
        logger.info("Getting model details from backend")
        try:
            response = self.data_science_client.get_model_artifact_content(
                model_id=model_ocid,
            )
            content = response.data.content
            self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, json.loads(content))
        except json.JSONDecodeError as e:
            logger.error(f"Error decoding JSON: {e}")
            raise
        except Exception as e:
            logger.error(f"An unexpected error occurred: {e}")
            raise
        return self

    def add(
        self,
        namespace: str,
        bucket: str,
        prefix: Optional[str] = None,
        files: Optional[List[str]] = None,
    ):
        """
        Adds information about objects in a specified bucket to the model description JSON.

        Parameters:
        - namespace (str): The namespace of the object storage.
        - bucket (str): The name of the bucket containing the objects.
        - prefix (str, optional): The prefix used to filter objects within the bucket. Defaults to None.
        - files (list of str, optional): A list of object names to include in the model description.
          If provided, only objects with matching names are included. Defaults to None.

        Returns:
        - None

        Raises:
        - ValueError: If no files are found to add to the model description.

        Note:
        - If `files` is not provided, all object versions listed under `prefix` are used.
          If `files` is provided, each name is looked up individually; names that cannot
          be found are logged and skipped, and `prefix` is only recorded in the entry.
        - Objects without a positive size are ignored.
        - An existing entry with the same namespace, bucket and prefix is replaced, but
          only after the new entry has been validated, so a failed call leaves the
          current description untouched.
        """

        def object_exists(object_name):
            # True only when a HEAD request on the object answers with 200.
            try:
                head_response = self.object_storage_client.head_object(
                    namespace, bucket, object_name=object_name
                )
            except Exception as e:
                # Best effort: a failed lookup is logged and the object skipped.
                if getattr(e, "status", None) == 404:
                    logger.error(f"File not found in bucket: {object_name}")
                else:
                    logger.error(f"An error occured: {e}")
                return False
            return head_response.status == 200

        def list_all_object_versions():
            # Follow the pagination tokens until every page has been collected.
            collected = []
            next_page = None
            while True:
                response = self.object_storage_client.list_object_versions(
                    namespace_name=namespace,
                    bucket_name=bucket,
                    prefix=prefix,
                    fields="name,size",
                    page=next_page,
                )
                collected.extend(response.data.items)
                if not response.has_next_page:
                    return collected
                next_page = response.next_page

        def find_object_version(object_name):
            # Listing by prefix can return other objects that merely start with
            # `object_name`, or nothing at all; keep the first exact name match.
            response = self.object_storage_client.list_object_versions(
                namespace_name=namespace,
                bucket_name=bucket,
                prefix=object_name,
                fields="name,size",
            )
            return next(
                (item for item in response.data.items if item.name == object_name),
                None,
            )

        # Fetch the object details, either for the whole prefix or per file name.
        if files is None:
            object_versions = list_all_object_versions()
        else:
            object_versions = []
            for file_name in files:
                if not object_exists(file_name):
                    continue
                version = find_object_version(file_name)
                if version is None:
                    logger.error(f"No object version found for file: {file_name}")
                    continue
                object_versions.append(version)

        # `size` may be missing on some listing entries - treat that like 0.
        objects = [
            {"name": obj.name, "version": obj.version_id, "sizeInBytes": obj.size}
            for obj in object_versions
            if (obj.size or 0) > 0
        ]

        if not objects:
            error_message = (
                f"No files to add in the bucket: {bucket} with namespace: {namespace} "
                f"and prefix: {prefix}. File names: {files}"
            )
            logger.error(error_message)
            raise ValueError(error_message)

        # Replace an already existing entry only now that the new one is valid.
        self.remove(namespace, bucket, prefix)

        description = self.model_file_description
        description["models"].append(
            {
                "namespace": namespace,
                "bucketName": bucket,
                "prefix": prefix,
                "objects": objects,
            }
        )
        self.set_spec(self.CONST_MODEL_FILE_DESCRIPTION, description)

    def remove(self, namespace: str, bucket: str, prefix: Optional[str] = None):
        """
        Removes information about objects in a specified bucket from the model description JSON.

        Parameters:
        - namespace (str): The namespace of the object storage.
        - bucket (str): The name of the bucket containing the objects.
        - prefix (str, optional): The prefix used to filter objects within the bucket. Defaults to None.

        Returns:
        - None

        Note:
        - The first entry whose namespace, bucket name and prefix all match is removed
          in place from the ``models`` list of the description.
        - An entry without a ``prefix`` key matches only when `prefix` is None.
        - If no matching entry is found, the method returns without making any changes.
        """
        models = self.model_file_description["models"]
        wanted = (namespace, bucket, prefix)
        for idx, model in enumerate(models):
            if (model["namespace"], model["bucketName"], model.get("prefix")) == wanted:
                # Returning right after the pop keeps the iteration safe.
                models.pop(idx)
                return

    def create(self):
        """
        Saves the model to the Model Catalog of Oracle Cloud Infrastructure (OCI) Data Science service.

        The description is written to a file with `build()`, registered as the
        model artifact, and the parent class `create()` is invoked. The file is
        deleted afterwards, whether or not the creation succeeded.

        Parameters:
        - None

        Returns:
        - str: The OCID of the created model.

        Raises:
        - Exception: Any error raised by `build()` or by the parent `create()`.

        Note:
        - A failure to delete the description file is logged but not raised, so
          that the OCID of an already created model is not lost.
        """
        tmp_file_path = self.build()
        # NOTE(review): with_artifact is used as a fluent setter here (presumably
        # it returns this same instance) - confirm against DataScienceModel.
        self.with_artifact(uri=tmp_file_path)
        try:
            created_model = super().create()
        finally:
            try:
                os.remove(tmp_file_path)
            except OSError as e:
                logger.error(f"Error occurred while cleaning file: {e}")
        return created_model.id

    def build(self) -> str:
        """
        Builds the model description JSON and writes it to a file.

        Parameters:
        - None

        Returns:
        - str: The absolute file path where the model description JSON is stored.

        Raises:
        - IOError: If the file cannot be written.
        - Exception: Any other error raised while serializing the description.

        Note:
        - The description is serialized to 'resultModelDescription.json' in the
          current working directory with an indentation of 2 spaces.
        - Errors are logged and re-raised; success is only reported when the
          file was actually written.
        """
        logger.info("Building...")
        file_path = "resultModelDescription.json"
        try:
            with open(file_path, "w", encoding="utf-8") as json_file:
                json.dump(self.model_file_description, json_file, indent=2)
        except IOError as e:
            logger.error(f"Error writing to file '{file_path}': {e}")
            raise
        except Exception as e:
            logger.error(f"An unexpected error occurred: {e}")
            raise
        logger.info("Model Artifact stored successfully.")
        return os.path.abspath(file_path)

    def show(self):
        """
        Logs and returns the current model description JSON in a human-readable format.

        Parameters:
        - None

        Returns:
        - str: The JSON representation of the current model description.

        Note:
        - The JSON representation is formatted with an indentation of 4 spaces.
        """
        description_json = json.dumps(self.model_file_description, indent=4)
        logger.info(description_json)
        return description_json
0 commit comments