1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8; -*-
3
+
4
+ # Copyright (c) 2024 Oracle and/or its affiliates.
5
+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
+
7
+ import logging
8
+ import sys
9
+ import time
10
+ from typing import Callable
11
+
12
+ import oci
13
+ from oci import Signer
14
+ from tqdm .auto import tqdm
15
+ from ads .common .oci_datascience import OCIDataScienceMixin
16
+
17
+ logger = logging .getLogger (__name__ )
18
+
19
+ WORK_REQUEST_STOP_STATE = ("SUCCEEDED" , "FAILED" , "CANCELED" )
20
+ DEFAULT_WAIT_TIME = 1200
21
+ DEFAULT_POLL_INTERVAL = 10
22
+ WORK_REQUEST_PERCENTAGE = 100
23
+ # default tqdm progress bar format:
24
+ # {l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]
25
+ # customize the bar format to remove the {n_fmt}/{total_fmt} from the right side
26
+ DEFAULT_BAR_FORMAT = '{l_bar}{bar}| [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]'
27
+
28
+
29
+ class DataScienceWorkRequest (OCIDataScienceMixin ):
30
+ """Class for monitoring OCI WorkRequest and representing on tqdm progress bar. This class inherits
31
+ `OCIDataScienceMixin` so as to call its `client` attribute to interact with OCI backend.
32
+ """
33
+
34
+ def __init__ (
35
+ self ,
36
+ id : str ,
37
+ description : str = "Processing" ,
38
+ config : dict = None ,
39
+ signer : Signer = None ,
40
+ client_kwargs : dict = None ,
41
+ ** kwargs
42
+ ) -> None :
43
+ """Initializes ADSWorkRequest object.
44
+
45
+ Parameters
46
+ ----------
47
+ id: str
48
+ Work Request OCID.
49
+ description: str
50
+ Progress bar initial step description (Defaults to `Processing`).
51
+ config : dict, optional
52
+ OCI API key config dictionary to initialize
53
+ oci.data_science.DataScienceClient (Defaults to None).
54
+ signer : oci.signer.Signer, optional
55
+ OCI authentication signer to initialize
56
+ oci.data_science.DataScienceClient (Defaults to None).
57
+ client_kwargs : dict, optional
58
+ Additional client keyword arguments to initialize
59
+ oci.data_science.DataScienceClient (Defaults to None).
60
+ kwargs:
61
+ Additional keyword arguments to initialize
62
+ oci.data_science.DataScienceClient.
63
+ """
64
+ self .id = id
65
+ self ._description = description
66
+ self ._percentage = 0
67
+ self ._status = None
68
+ super ().__init__ (config , signer , client_kwargs , ** kwargs )
69
+
70
+
71
+ def _sync (self ):
72
+ """Fetches the latest work request information to ADSWorkRequest object."""
73
+ work_request = self .client .get_work_request (self .id ).data
74
+ work_request_logs = self .client .list_work_request_logs (
75
+ self .id
76
+ ).data
77
+
78
+ self ._percentage = work_request .percent_complete
79
+ self ._status = work_request .status
80
+ self ._description = work_request_logs [- 1 ].message if work_request_logs else "Processing"
81
+
82
+ def watch (
83
+ self ,
84
+ progress_callback : Callable ,
85
+ max_wait_time : int = DEFAULT_WAIT_TIME ,
86
+ poll_interval : int = DEFAULT_POLL_INTERVAL ,
87
+ ):
88
+ """Updates the progress bar with realtime message and percentage until the process is completed.
89
+
90
+ Parameters
91
+ ----------
92
+ progress_callback: Callable
93
+ Progress bar callback function.
94
+ It must accept `(percent_change, description)` where `percent_change` is the
95
+ work request percent complete and `description` is the latest work request log message.
96
+ max_wait_time: int
97
+ Maximum amount of time to wait in seconds (Defaults to 1200).
98
+ Negative implies infinite wait time.
99
+ poll_interval: int
100
+ Poll interval in seconds (Defaults to 10).
101
+
102
+ Returns
103
+ -------
104
+ None
105
+ """
106
+ previous_percent_complete = 0
107
+
108
+ start_time = time .time ()
109
+ while self ._percentage < 100 :
110
+
111
+ seconds_since = time .time () - start_time
112
+ if max_wait_time > 0 and seconds_since >= max_wait_time :
113
+ logger .error (f"Exceeded max wait time of { max_wait_time } seconds." )
114
+ return
115
+
116
+ time .sleep (poll_interval )
117
+
118
+ try :
119
+ self ._sync ()
120
+ except Exception as ex :
121
+ logger .warn (ex )
122
+ continue
123
+
124
+ percent_change = self ._percentage - previous_percent_complete
125
+ previous_percent_complete = self ._percentage
126
+ progress_callback (
127
+ percent_change = percent_change ,
128
+ description = self ._description
129
+ )
130
+
131
+ if self ._status in WORK_REQUEST_STOP_STATE :
132
+ if self ._status != oci .work_requests .models .WorkRequest .STATUS_SUCCEEDED :
133
+ if self ._description :
134
+ raise Exception (self ._description )
135
+ else :
136
+ raise Exception (
137
+ "Error occurred in attempt to perform the operation. "
138
+ "Check the service logs to get more details. "
139
+ f"Work request id: { self .id } ."
140
+ )
141
+ else :
142
+ break
143
+
144
+ progress_callback (percent_change = 0 , description = "Done" )
145
+
146
+ def wait_work_request (
147
+ self ,
148
+ progress_bar_description : str = "Processing" ,
149
+ max_wait_time : int = DEFAULT_WAIT_TIME ,
150
+ poll_interval : int = DEFAULT_POLL_INTERVAL
151
+ ):
152
+ """Waits for the work request progress bar to be completed.
153
+
154
+ Parameters
155
+ ----------
156
+ progress_bar_description: str
157
+ Progress bar initial step description (Defaults to `Processing`).
158
+ max_wait_time: int
159
+ Maximum amount of time to wait in seconds (Defaults to 1200).
160
+ Negative implies infinite wait time.
161
+ poll_interval: int
162
+ Poll interval in seconds (Defaults to 10).
163
+
164
+ Returns
165
+ -------
166
+ None
167
+ """
168
+
169
+ with tqdm (
170
+ total = WORK_REQUEST_PERCENTAGE ,
171
+ leave = False ,
172
+ mininterval = 0 ,
173
+ file = sys .stdout ,
174
+ desc = progress_bar_description ,
175
+ bar_format = DEFAULT_BAR_FORMAT
176
+ ) as pbar :
177
+
178
+ def progress_callback (percent_change , description ):
179
+ if percent_change != 0 :
180
+ pbar .update (percent_change )
181
+ if description :
182
+ pbar .set_description (description )
183
+
184
+ self .watch (
185
+ progress_callback = progress_callback ,
186
+ max_wait_time = max_wait_time ,
187
+ poll_interval = poll_interval
188
+ )
189
+
0 commit comments