Skip to content

Commit ecc5347

Browse files
committed
Updated pr and added unit tests.
1 parent 7e8bd6b commit ecc5347

File tree

2 files changed

+137
-6
lines changed

2 files changed

+137
-6
lines changed

ads/common/work_request.py

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828

2929
class ADSWorkRequest(OCIDataScienceMixin):
30+
"""Class for monitoring OCI WorkRequest and representing on tqdm progress bar. This class inherits
31+
`OCIDataScienceMixin` so as to call its `client` attribute to interact with OCI backend.
32+
"""
3033

3134
def __init__(
3235
self,
@@ -37,6 +40,27 @@ def __init__(
3740
client_kwargs: dict = None,
3841
**kwargs
3942
) -> None:
43+
"""Initializes ADSWorkRequest object.
44+
45+
Parameters
46+
----------
47+
id: str
48+
Work Request OCID.
49+
description: str
50+
Progress bar initial step description (Defaults to `Processing`).
51+
config : dict, optional
52+
OCI API key config dictionary to initialize
53+
oci.data_science.DataScienceClient (Defaults to None).
54+
signer : oci.signer.Signer, optional
55+
OCI authentication signer to initialize
56+
oci.data_science.DataScienceClient (Defaults to None).
57+
client_kwargs : dict, optional
58+
Additional client keyword arguments to initialize
59+
oci.data_science.DataScienceClient (Defaults to None).
60+
kwargs:
61+
Additional keyword arguments to initialize
62+
oci.data_science.DataScienceClient.
63+
"""
4064
self.id = id
4165
self._description = description
4266
self._percentage = 0
@@ -45,6 +69,7 @@ def __init__(
4569

4670

4771
def _sync(self):
72+
"""Fetches the latest work request information to ADSWorkRequest object."""
4873
work_request = self.client.get_work_request(self.id).data
4974
work_request_logs = self.client.list_work_request_logs(
5075
self.id
@@ -57,17 +82,35 @@ def _sync(self):
5782
def watch(
5883
self,
5984
progress_callback: Callable,
60-
max_wait_time: int,
61-
poll_interval: int,
85+
max_wait_time: int=DEFAULT_WAIT_TIME,
86+
poll_interval: int=DEFAULT_POLL_INTERVAL,
6287
):
88+
"""Updates the progress bar with realtime message and percentage until the process is completed.
89+
90+
Parameters
91+
----------
92+
progress_callback: Callable
93+
Progress bar callback function.
94+
It must accept `(percent_change, description)` where `percent_change` is the
95+
work request percent complete and `description` is the latest work request log message.
96+
max_wait_time: int
97+
Maximum amount of time to wait in seconds (Defaults to 1200).
98+
Negative implies infinite wait time.
99+
poll_interval: int
100+
Poll interval in seconds (Defaults to 10).
101+
102+
Returns
103+
-------
104+
None
105+
"""
63106
previous_percent_complete = 0
64107

65108
start_time = time.time()
66109
while self._percentage < 100:
67110

68111
seconds_since = time.time() - start_time
69112
if max_wait_time > 0 and seconds_since >= max_wait_time:
70-
logger.error(f"Max wait time ({max_wait_time} seconds) exceeded.")
113+
logger.error(f"Exceeded max wait time of {max_wait_time} seconds.")
71114
return
72115

73116
time.sleep(poll_interval)
@@ -80,7 +123,10 @@ def watch(
80123

81124
percent_change = self._percentage - previous_percent_complete
82125
previous_percent_complete = self._percentage
83-
progress_callback(percent_change, self._description)
126+
progress_callback(
127+
progress=percent_change,
128+
description=self._description
129+
)
84130

85131
if self._status in WORK_REQUEST_STOP_STATE:
86132
if self._status != oci.work_requests.models.WorkRequest.STATUS_SUCCEEDED:
@@ -90,19 +136,38 @@ def watch(
90136
raise Exception(
91137
"Error occurred in attempt to perform the operation. "
92138
"Check the service logs to get more details. "
139+
f"Work request id: {self.id}."
93140
)
94141
else:
95142
break
96143

97-
progress_callback(0, "Done")
144+
progress_callback(progress=0, description="Done")
98145

99146

100147
def wait_work_request(
101148
id: str,
102-
progress_bar_description: str,
149+
progress_bar_description: str="Processing",
103150
max_wait_time: int=DEFAULT_WAIT_TIME,
104151
poll_interval: int=DEFAULT_POLL_INTERVAL
105152
):
153+
"""Waits for the work request progress bar to be completed.
154+
155+
Parameters
156+
----------
157+
id: str
158+
Work Request OCID.
159+
progress_bar_description: str
160+
Progress bar initial step description (Defaults to `Processing`).
161+
max_wait_time: int
162+
Maximum amount of time to wait in seconds (Defaults to 1200).
163+
Negative implies infinite wait time.
164+
poll_interval: int
165+
Poll interval in seconds (Defaults to 10).
166+
167+
Returns
168+
-------
169+
None
170+
"""
106171
ads_work_request = ADSWorkRequest(id)
107172

108173
with tqdm(
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2023 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
import pytest
7+
from unittest.mock import MagicMock, patch
8+
from ads.common.work_request import ADSWorkRequest
9+
10+
11+
class TestADSWorkRequest:
12+
13+
@patch("ads.common.work_request.ADSWorkRequest._sync")
14+
@patch("ads.common.oci_datascience.OCIDataScienceMixin.__init__")
15+
def test_watch_succeed(self, mock_oci_datascience, mock_sync):
16+
ads_work_request = ADSWorkRequest(
17+
id="test_id",
18+
description = "Processing"
19+
)
20+
ads_work_request._percentage = 90
21+
ads_work_request._status = "SUCCEEDED"
22+
ads_work_request.watch(
23+
progress_callback=MagicMock(),
24+
poll_interval=0
25+
)
26+
mock_oci_datascience.assert_called()
27+
mock_sync.assert_called()
28+
29+
@patch("ads.common.work_request.ADSWorkRequest._sync")
30+
@patch("ads.common.oci_datascience.OCIDataScienceMixin.__init__")
31+
def test_watch_failed_with_description(self, mock_oci_datascience, mock_sync):
32+
ads_work_request = ADSWorkRequest(
33+
id="test_id",
34+
description = "Backend Error"
35+
)
36+
ads_work_request._percentage = 30
37+
ads_work_request._status = "FAILED"
38+
with pytest.raises(Exception, match="Backend Error"):
39+
ads_work_request.watch(
40+
progress_callback=MagicMock(),
41+
poll_interval=0
42+
)
43+
mock_oci_datascience.assert_called()
44+
mock_sync.assert_called()
45+
46+
@patch("ads.common.work_request.ADSWorkRequest._sync")
47+
@patch("ads.common.oci_datascience.OCIDataScienceMixin.__init__")
48+
def test_watch_failed_without_description(self, mock_oci_datascience, mock_sync):
49+
ads_work_request = ADSWorkRequest(
50+
id="test_id",
51+
description = None
52+
)
53+
ads_work_request._percentage = 30
54+
ads_work_request._status = "FAILED"
55+
with pytest.raises(
56+
Exception,
57+
match="Error occurred in attempt to perform the operation. "
58+
"Check the service logs to get more details. "
59+
f"Work request id: {ads_work_request.id}"
60+
):
61+
ads_work_request.watch(
62+
progress_callback=MagicMock(),
63+
poll_interval=0
64+
)
65+
mock_oci_datascience.assert_called()
66+
mock_sync.assert_called()

0 commit comments

Comments
 (0)