Skip to content

Commit d2d7182

Browse files
committed
Improved progress bar.
1 parent 18a9fef commit d2d7182

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed

ads/common/oci_mixin.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
import logging
1212
import os
1313
import re
14+
import sys
1415
import time
1516
import traceback
1617
from datetime import date, datetime
1718
from typing import Callable, Optional, Union
1819
from enum import Enum
1920

2021
import oci
22+
import tqdm
2123
import yaml
2224
from ads.common import auth
2325
from ads.common.decorator.utils import class_or_instance_method
@@ -1038,3 +1040,95 @@ def from_name(cls, name: str, compartment_id: Optional[str] = None):
10381040
if not res:
10391041
raise OCIModelNotExists()
10401042
return cls.from_oci_model(res[0])
1043+
1044+
1045+
class ADSWorkRequest(OCIClientMixin):
1046+
1047+
def __init__(self, id: str, description: str = "Processing"):
1048+
self.id = id
1049+
self._description = description
1050+
self._percentage = 0
1051+
self._status = None
1052+
1053+
def _sync(self):
1054+
try:
1055+
work_request = self.client.get_work_request(self.id).data
1056+
work_request_logs = self.client.list_work_request_logs(
1057+
self.id
1058+
).data
1059+
1060+
self._percentage= work_request.percent_complete
1061+
self._status = work_request.status
1062+
self._description = work_request_logs[:-1]
1063+
except Exception as ex:
1064+
logger.warn(ex)
1065+
1066+
def watch(
1067+
self,
1068+
progress_callback: Callable,
1069+
max_wait_time: int,
1070+
poll_interval: int,
1071+
):
1072+
previous_percent_complete = 0
1073+
previous_log = None
1074+
1075+
start_time = time.time()
1076+
while self._percentage < 100:
1077+
1078+
seconds_since = time.time() - start_time
1079+
if max_wait_time > 0 and seconds_since >= max_wait_time:
1080+
logger.error(f"Max wait time ({max_wait_time} seconds) exceeded.")
1081+
return
1082+
1083+
time.sleep(poll_interval)
1084+
self._sync()
1085+
percent_change = self._percentage - previous_percent_complete
1086+
previous_percent_complete = self._percentage
1087+
description = self._description if previous_log != self._description else ""
1088+
progress_callback(
1089+
percent_change=percent_change,
1090+
description=description
1091+
)
1092+
previous_log = self._description
1093+
1094+
if self._status in WORK_REQUEST_STOP_STATE:
1095+
if self._status != oci.work_requests.models.WorkRequest.STATUS_SUCCEEDED:
1096+
if self._description:
1097+
raise Exception(self._description)
1098+
else:
1099+
raise Exception(
1100+
"Error occurred in attempt to perform the operation. "
1101+
"Check the service logs to get more details. "
1102+
)
1103+
else:
1104+
break
1105+
1106+
progress_callback(percent_change=0, description="Done")
1107+
1108+
1109+
def wait_work_request(
1110+
id: str,
1111+
desc: str,
1112+
max_wait_time: int=DEFAULT_WAIT_TIME,
1113+
poll_interval: int=DEFAULT_POLL_INTERVAL
1114+
):
1115+
ads_work_request = ADSWorkRequest(id)
1116+
1117+
with tqdm(
1118+
leave=False,
1119+
file=sys.stdout,
1120+
desc=desc,
1121+
) as pbar:
1122+
1123+
def progress_callback(percent_change, description):
1124+
if percent_change != 0:
1125+
pbar.update(percent_change)
1126+
if description:
1127+
pbar.set_description(description)
1128+
1129+
ads_work_request.watch(
1130+
progress_callback,
1131+
max_wait_time=max_wait_time,
1132+
poll_interval=poll_interval
1133+
)
1134+

0 commit comments

Comments
 (0)