Skip to content

Commit 35e8464

Browse files
committed
Reformatted using black.
1 parent 5f3b316 commit 35e8464

File tree

1 file changed

+79
-55
lines changed

1 file changed

+79
-55
lines changed

ads/model/model_description.py

Lines changed: 79 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,38 +8,45 @@
88
import ads
99
from ads.common import logger
1010

11+
1112
class ModelDescription:
1213

1314
empty_json = {
1415
"version": "1.0",
1516
"type": "modelOSSReferenceDescription",
16-
"models": []
17+
"models": [],
1718
}
18-
19-
def auth(self):
19+
20+
def auth(self):
2021
authData = ads.common.auth.default_signer()
21-
signer = authData['signer']
22-
self.region = authData['config']['region']
22+
signer = authData["signer"]
23+
self.region = authData["config"]["region"]
2324

2425
# data science client
25-
self.data_science_client = oci.data_science.DataScienceClient({'region': self.region}, signer=signer)
26+
self.data_science_client = oci.data_science.DataScienceClient(
27+
{"region": self.region}, signer=signer
28+
)
2629
# oss client
27-
self.object_storage_client = oci.object_storage.ObjectStorageClient({'region': self.region}, signer = signer)
28-
30+
self.object_storage_client = oci.object_storage.ObjectStorageClient(
31+
{"region": self.region}, signer=signer
32+
)
33+
2934
def __init__(self, model_ocid=None):
3035

31-
self.region = ''
36+
self.region = ""
3237
self.auth()
3338

34-
if model_ocid == None:
39+
if model_ocid == None:
3540
# if no model given then start from scratch
3641
self.modelDescriptionJson = self.empty_json
3742
else:
3843
# if model given then get that as the starting reference point
3944
logger.info("Getting model details from backend")
4045
try:
41-
get_model_artifact_content_response = self.data_science_client.get_model_artifact_content(
42-
model_id=model_ocid,
46+
get_model_artifact_content_response = (
47+
self.data_science_client.get_model_artifact_content(
48+
model_id=model_ocid,
49+
)
4350
)
4451
content = get_model_artifact_content_response.data.content
4552
self.modelDescriptionJson = json.loads(content)
@@ -57,16 +64,18 @@ def add(self, namespace, bucket, prefix=None, files=None):
5764
def checkIfFileExists(fileName):
5865
isExists = False
5966
try:
60-
headResponse = self.object_storage_client.head_object(namespace, bucket, object_name=fileName)
67+
headResponse = self.object_storage_client.head_object(
68+
namespace, bucket, object_name=fileName
69+
)
6170
if headResponse.status == 200:
6271
isExists = True
6372
except Exception as e:
64-
if hasattr(e, 'status') and e.status == 404:
73+
if hasattr(e, "status") and e.status == 404:
6574
logger.error(f"File not found in bucket: {fileName}")
6675
else:
6776
logger.error(f"An error occured: {e}")
6877
return isExists
69-
78+
7079
# Function to un-paginate the api call with while loop
7180
def listObjectVersionsUnpaginated():
7281
objectStorageList = []
@@ -77,8 +86,8 @@ def listObjectVersionsUnpaginated():
7786
bucket_name=bucket,
7887
prefix=prefix,
7988
fields="name,size",
80-
page = opc_next_page
81-
)
89+
page=opc_next_page,
90+
)
8291
objectStorageList.extend(response.data.items)
8392
has_next_page = response.has_next_page
8493
opc_next_page = response.next_page
@@ -91,38 +100,46 @@ def listObjectVersionsUnpaginated():
91100
else:
92101
for fileName in files:
93102
if checkIfFileExists(fileName=fileName):
94-
objectStorageList.append(self.object_storage_client.list_object_versions(
95-
namespace_name=namespace,
96-
bucket_name=bucket,
97-
prefix=fileName,
98-
fields="name,size",
99-
).data.items[0])
100-
101-
objects = [{
102-
"name": obj.name,
103-
"version": obj.version_id,
104-
"sizeInBytes": obj.size
105-
} for obj in objectStorageList if obj.size > 0]
106-
103+
objectStorageList.append(
104+
self.object_storage_client.list_object_versions(
105+
namespace_name=namespace,
106+
bucket_name=bucket,
107+
prefix=fileName,
108+
fields="name,size",
109+
).data.items[0]
110+
)
111+
112+
objects = [
113+
{"name": obj.name, "version": obj.version_id, "sizeInBytes": obj.size}
114+
for obj in objectStorageList
115+
if obj.size > 0
116+
]
117+
107118
if len(objects) == 0:
108119
error_message = (
109120
f"No files to add in the bucket: {bucket} with namespace: {namespace} "
110121
f"and prefix: {prefix}. File names: {files}"
111122
)
112123
logger.error(error_message)
113124
raise ValueError(error_message)
114-
115-
self.modelDescriptionJson['models'].append({
116-
"namespace": namespace,
117-
"bucketName": bucket,
118-
"prefix": prefix,
119-
"objects": objects
120-
})
121-
125+
126+
self.modelDescriptionJson["models"].append(
127+
{
128+
"namespace": namespace,
129+
"bucketName": bucket,
130+
"prefix": prefix,
131+
"objects": objects,
132+
}
133+
)
134+
122135
def remove(self, namespace, bucket, prefix=None):
123136
def findModelIdx():
124-
for idx, model in enumerate(self.modelDescriptionJson['models']):
125-
if (model['namespace'], model['bucketName'], (model['prefix'] if ('prefix' in model) else None) ) == (namespace, bucket, prefix):
137+
for idx, model in enumerate(self.modelDescriptionJson["models"]):
138+
if (
139+
model["namespace"],
140+
model["bucketName"],
141+
(model["prefix"] if ("prefix" in model) else None),
142+
) == (namespace, bucket, prefix):
126143
return idx
127144
return -1
128145

@@ -131,7 +148,7 @@ def findModelIdx():
131148
return
132149
else:
133150
# model found case
134-
self.modelDescriptionJson['models'].pop(modelSearchIdx)
151+
self.modelDescriptionJson["models"].pop(modelSearchIdx)
135152

136153
def show(self):
137154
logger.info(json.dumps(self.modelDescriptionJson, indent=4))
@@ -143,30 +160,37 @@ def build(self):
143160
with open(file_path, "w") as json_file:
144161
json.dump(self.modelDescriptionJson, json_file, indent=2)
145162
except IOError as e:
146-
logger.error(f"Error writing to file '{file_path}': {e}") # Handle the exception accordingly, e.g., log the error, retry writing, etc.
163+
logger.error(
164+
f"Error writing to file '{file_path}': {e}"
165+
) # Handle the exception accordingly, e.g., log the error, retry writing, etc.
147166
except Exception as e:
148-
logger.error(f"An unexpected error occurred: {e}") # Handle other unexpected exceptions
167+
logger.error(
168+
f"An unexpected error occurred: {e}"
169+
) # Handle other unexpected exceptions
149170
logger.info("Model Artifact stored at location: 'resultModelDescription.json'")
150171
return os.path.abspath(file_path)
151-
172+
152173
def save(self, project_ocid, compartment_ocid, display_name=None):
153-
display_name = 'Created by MMS SDK on ' + datetime.datetime.now(pytz.utc).strftime('%Y-%m-%d %H:%M:%S %Z') if (display_name == None) else display_name
154-
customMetadataList = [
155-
Metadata(key="modelDescription", value = "true")
156-
]
174+
display_name = (
175+
"Created by MMS SDK on "
176+
+ datetime.datetime.now(pytz.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
177+
if (display_name == None)
178+
else display_name
179+
)
180+
customMetadataList = [Metadata(key="modelDescription", value="true")]
157181
model_details = oci.data_science.models.CreateModelDetails(
158-
compartment_id = compartment_ocid,
159-
project_id = project_ocid,
160-
display_name = display_name,
161-
custom_metadata_list = customMetadataList
182+
compartment_id=compartment_ocid,
183+
project_id=project_ocid,
184+
display_name=display_name,
185+
custom_metadata_list=customMetadataList,
162186
)
163187
logger.info("Created model details")
164188
model = self.data_science_client.create_model(model_details)
165189
logger.info("Created model")
166190
self.data_science_client.create_model_artifact(
167191
model.data.id,
168192
json.dumps(self.modelDescriptionJson),
169-
content_disposition='attachment; filename="modelDescription.json"'
193+
content_disposition='attachment; filename="modelDescription.json"',
170194
)
171-
logger.info('Successfully created model with OCID: ', model.data.id)
172-
return model.data.id
195+
logger.info("Successfully created model with OCID: ", model.data.id)
196+
return model.data.id

0 commit comments

Comments
 (0)