Skip to content

Commit 8e24e15

Browse files
committed
Updated pr.
1 parent efd32bf commit 8e24e15

File tree

4 files changed

+131
-167
lines changed

4 files changed

+131
-167
lines changed

ads/common/oci_mixin.py

Lines changed: 0 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,96 +1040,3 @@ def from_name(cls, name: str, compartment_id: Optional[str] = None):
10401040
if not res:
10411041
raise OCIModelNotExists()
10421042
return cls.from_oci_model(res[0])
1043-
1044-
1045-
class ADSWorkRequest(OCIClientMixin):
1046-
1047-
def __init__(self, id: str, description: str = "Processing"):
1048-
self.id = id
1049-
self._description = description
1050-
self._percentage = 0
1051-
self._status = None
1052-
1053-
def _sync(self):
1054-
work_request = self.client.get_work_request(self.id).data
1055-
work_request_logs = self.client.list_work_request_logs(
1056-
self.id
1057-
).data
1058-
1059-
self._percentage= work_request.percent_complete
1060-
self._status = work_request.status
1061-
self._description = work_request_logs[:-1]
1062-
1063-
def watch(
1064-
self,
1065-
progress_callback: Callable,
1066-
max_wait_time: int,
1067-
poll_interval: int,
1068-
):
1069-
previous_percent_complete = 0
1070-
previous_log = None
1071-
1072-
start_time = time.time()
1073-
while self._percentage < 100:
1074-
1075-
seconds_since = time.time() - start_time
1076-
if max_wait_time > 0 and seconds_since >= max_wait_time:
1077-
logger.error(f"Max wait time ({max_wait_time} seconds) exceeded.")
1078-
return
1079-
1080-
time.sleep(poll_interval)
1081-
1082-
try:
1083-
self._sync()
1084-
except Exception as ex:
1085-
logger.warn(ex)
1086-
continue
1087-
1088-
percent_change = self._percentage - previous_percent_complete
1089-
previous_percent_complete = self._percentage
1090-
description = self._description if previous_log != self._description else ""
1091-
progress_callback(percent_change, description)
1092-
previous_log = self._description
1093-
1094-
if self._status in WORK_REQUEST_STOP_STATE:
1095-
if self._status != oci.work_requests.models.WorkRequest.STATUS_SUCCEEDED:
1096-
if self._description:
1097-
raise Exception(self._description)
1098-
else:
1099-
raise Exception(
1100-
"Error occurred in attempt to perform the operation. "
1101-
"Check the service logs to get more details. "
1102-
)
1103-
else:
1104-
break
1105-
1106-
progress_callback(0, "Done")
1107-
1108-
1109-
def wait_work_request(
1110-
id: str,
1111-
progress_bar_description: str,
1112-
max_wait_time: int=DEFAULT_WAIT_TIME,
1113-
poll_interval: int=DEFAULT_POLL_INTERVAL
1114-
):
1115-
ads_work_request = ADSWorkRequest(id)
1116-
1117-
with tqdm(
1118-
leave=False,
1119-
mininterval=0,
1120-
file=sys.stdout,
1121-
desc=progress_bar_description,
1122-
) as pbar:
1123-
1124-
def progress_callback(percent_change, description):
1125-
if percent_change != 0:
1126-
pbar.update(percent_change)
1127-
if description:
1128-
pbar.set_description(description)
1129-
1130-
ads_work_request.watch(
1131-
progress_callback,
1132-
max_wait_time,
1133-
poll_interval
1134-
)
1135-

ads/common/work_request.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
2+
import logging
3+
import sys
4+
import time
5+
from typing import Callable
6+
7+
import oci
8+
from oci import Signer
9+
from tqdm.auto import tqdm
10+
from ads.common.oci_datascience import OCIDataScienceMixin
11+
12+
logger = logging.getLogger(__name__)
13+
14+
WORK_REQUEST_STOP_STATE = ("SUCCEEDED", "FAILED", "CANCELED")
15+
DEFAULT_WAIT_TIME = 1200
16+
DEFAULT_POLL_INTERVAL = 10
17+
WORK_REQUEST_PERCENTAGE = 100
18+
# default tqdm progress bar format:
19+
# {l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]
20+
# customize the bar format to remove the {n_fmt}/{total_fmt} from the right side
21+
DEFAULT_BAR_FORMAT = '{l_bar}{bar}| [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]'
22+
23+
24+
class ADSWorkRequest(OCIDataScienceMixin):
25+
26+
def __init__(
27+
self,
28+
id: str,
29+
description: str = "Processing",
30+
config: dict = None,
31+
signer: Signer = None,
32+
client_kwargs: dict = None,
33+
**kwargs
34+
) -> None:
35+
self.id = id
36+
self._description = description
37+
self._percentage = 0
38+
self._status = None
39+
super().__init__(config, signer, client_kwargs, **kwargs)
40+
41+
42+
def _sync(self):
43+
work_request = self.client.get_work_request(self.id).data
44+
work_request_logs = self.client.list_work_request_logs(
45+
self.id
46+
).data
47+
48+
self._percentage= work_request.percent_complete
49+
self._status = work_request.status
50+
self._description = work_request_logs[-1].message if work_request_logs else "Processing"
51+
52+
def watch(
53+
self,
54+
progress_callback: Callable,
55+
max_wait_time: int,
56+
poll_interval: int,
57+
):
58+
previous_percent_complete = 0
59+
60+
start_time = time.time()
61+
while self._percentage < 100:
62+
63+
seconds_since = time.time() - start_time
64+
if max_wait_time > 0 and seconds_since >= max_wait_time:
65+
logger.error(f"Max wait time ({max_wait_time} seconds) exceeded.")
66+
return
67+
68+
time.sleep(poll_interval)
69+
70+
try:
71+
self._sync()
72+
except Exception as ex:
73+
logger.warn(ex)
74+
continue
75+
76+
percent_change = self._percentage - previous_percent_complete
77+
previous_percent_complete = self._percentage
78+
progress_callback(percent_change, self._description)
79+
80+
if self._status in WORK_REQUEST_STOP_STATE:
81+
if self._status != oci.work_requests.models.WorkRequest.STATUS_SUCCEEDED:
82+
if self._description:
83+
raise Exception(self._description)
84+
else:
85+
raise Exception(
86+
"Error occurred in attempt to perform the operation. "
87+
"Check the service logs to get more details. "
88+
)
89+
else:
90+
break
91+
92+
progress_callback(0, "Done")
93+
94+
95+
def wait_work_request(
96+
id: str,
97+
progress_bar_description: str,
98+
max_wait_time: int=DEFAULT_WAIT_TIME,
99+
poll_interval: int=DEFAULT_POLL_INTERVAL
100+
):
101+
ads_work_request = ADSWorkRequest(id)
102+
103+
with tqdm(
104+
total=WORK_REQUEST_PERCENTAGE,
105+
leave=False,
106+
mininterval=0,
107+
file=sys.stdout,
108+
desc=progress_bar_description,
109+
bar_format=DEFAULT_BAR_FORMAT
110+
) as pbar:
111+
112+
def progress_callback(percent_change, description):
113+
if percent_change != 0:
114+
pbar.update(percent_change)
115+
if description:
116+
pbar.set_description(description)
117+
118+
ads_work_request.watch(
119+
progress_callback,
120+
max_wait_time,
121+
poll_interval
122+
)
123+

ads/model/service/oci_datascience_model.py

Lines changed: 3 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ads.common.oci_mixin import OCIWorkRequestMixin
1818
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
1919
from ads.common.utils import extract_region
20+
from ads.common.work_request import wait_work_request
2021
from ads.model.deployment import ModelDeployment
2122
from oci.data_science.models import (
2223
ArtifactExportDetailsObjectStorage,
@@ -361,10 +362,7 @@ def import_model_artifact(self, bucket_uri: str, region: str = None) -> None:
361362
).headers["opc-work-request-id"]
362363

363364
# Show progress of importing artifacts
364-
self._wait_for_work_request(
365-
work_request_id=work_request_id,
366-
num_steps=2,
367-
)
365+
wait_work_request(work_request_id)
368366
except ServiceError as ex:
369367
if ex.status == 404:
370368
raise ModelArtifactNotFoundError()
@@ -408,10 +406,7 @@ def export_model_artifact(self, bucket_uri: str, region: str = None):
408406
).headers["opc-work-request-id"]
409407

410408
# Show progress of exporting model artifacts
411-
self._wait_for_work_request(
412-
work_request_id=work_request_id,
413-
num_steps=2,
414-
)
409+
wait_work_request(work_request_id)
415410

416411
@check_for_model_id(
417412
msg="Model needs to be saved to the Model Catalog before it can be updated."
@@ -540,63 +535,3 @@ def from_id(cls, ocid: str) -> "OCIDataScienceModel":
540535
if not ocid:
541536
raise ValueError("Model OCID not provided.")
542537
return super().from_ocid(ocid)
543-
544-
def _wait_for_work_request(self, work_request_id: str, num_steps: int = 3) -> None:
545-
"""Waits for the work request to be completed.
546-
547-
Parameters
548-
----------
549-
work_request_id: str
550-
Work Request OCID.
551-
num_steps: (int, optional). Defaults to 3.
552-
Number of steps for the progress indicator.
553-
554-
Returns
555-
-------
556-
None
557-
"""
558-
STOP_STATE = (
559-
WorkRequest.STATUS_SUCCEEDED,
560-
WorkRequest.STATUS_CANCELED,
561-
WorkRequest.STATUS_FAILED,
562-
)
563-
work_request_logs = []
564-
565-
i = 0
566-
with utils.get_progress_bar(num_steps) as progress:
567-
while not work_request_logs or len(work_request_logs) < num_steps:
568-
time.sleep(_REQUEST_INTERVAL_IN_SEC)
569-
new_work_request_logs = []
570-
571-
try:
572-
work_request = self.client.get_work_request(work_request_id).data
573-
work_request_logs = self.client.list_work_request_logs(
574-
work_request_id
575-
).data
576-
except Exception as ex:
577-
logger.warn(ex)
578-
579-
new_work_request_logs = (
580-
work_request_logs[i:] if work_request_logs else []
581-
)
582-
583-
for wr_item in new_work_request_logs:
584-
progress.update(wr_item.message)
585-
i += 1
586-
587-
if work_request and work_request.status in STOP_STATE:
588-
if work_request.status != WorkRequest.STATUS_SUCCEEDED:
589-
if new_work_request_logs:
590-
raise Exception(new_work_request_logs[-1].message)
591-
else:
592-
raise Exception(
593-
"Error occurred in attempt to perform the operation. "
594-
"Check the service logs to get more details. "
595-
f"{work_request}"
596-
)
597-
else:
598-
break
599-
600-
while i < num_steps:
601-
progress.update()
602-
i += 1

ads/model/service/oci_datascience_model_deployment.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
from typing import Callable, List
1010
from ads.common.oci_datascience import OCIDataScienceMixin
11-
from ads.common.oci_mixin import OCIWorkRequestMixin, wait_work_request
11+
from ads.common.work_request import wait_work_request
1212
from ads.config import PROJECT_OCID
1313
from ads.model.deployment.common.utils import OCIClientManager, State
1414
import oci
@@ -77,7 +77,6 @@ class MissingModelDeploymentWorkflowIdError(Exception): # pragma: no cover
7777

7878
class OCIDataScienceModelDeployment(
7979
OCIDataScienceMixin,
80-
OCIWorkRequestMixin,
8180
oci.data_science.models.ModelDeployment,
8281
):
8382
"""Represents an OCI Data Science Model Deployment.
@@ -192,7 +191,7 @@ def activate(
192191
try:
193192
wait_work_request(
194193
self.workflow_req_id,
195-
f"Activating model deployment: {self.display_name}",
194+
"Activating model deployment",
196195
max_wait_time,
197196
poll_interval
198197
)
@@ -243,7 +242,7 @@ def create(
243242
try:
244243
wait_work_request(
245244
self.workflow_req_id,
246-
f"Creating model deployment: {self.display_name}",
245+
"Creating model deployment",
247246
max_wait_time,
248247
poll_interval
249248
)
@@ -297,7 +296,7 @@ def deactivate(
297296
try:
298297
wait_work_request(
299298
self.workflow_req_id,
300-
f"Deactivating model deployment: {self.display_name}",
299+
"Deactivating model deployment",
301300
max_wait_time,
302301
poll_interval
303302
)
@@ -366,7 +365,7 @@ def delete(
366365
try:
367366
wait_work_request(
368367
self.workflow_req_id,
369-
f"Deleting model deployment: {self.display_name}",
368+
"Deleting model deployment",
370369
max_wait_time,
371370
poll_interval
372371
)

0 commit comments

Comments
 (0)