Skip to content

Commit ea0bfae

Browse files
authored
Improved model deployment progress bar (#144)
2 parents cd2c8f3 + 72a5dad commit ea0bfae

File tree

3 files changed

+107
-170
lines changed

3 files changed

+107
-170
lines changed

ads/common/oci_mixin.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import logging
1212
import os
1313
import re
14+
import time
1415
import traceback
1516
from datetime import date, datetime
1617
from typing import Callable, Optional, Union
@@ -20,7 +21,7 @@
2021
import yaml
2122
from ads.common import auth
2223
from ads.common.decorator.utils import class_or_instance_method
23-
from ads.common.utils import camel_to_snake
24+
from ads.common.utils import camel_to_snake, get_progress_bar
2425
from ads.config import COMPARTMENT_OCID
2526
from dateutil import tz
2627
from dateutil.parser import parse
@@ -29,6 +30,10 @@
2930
logger = logging.getLogger(__name__)
3031

3132
LIFECYCLE_STOP_STATE = ("SUCCEEDED", "FAILED", "CANCELED", "DELETED")
33+
WORK_REQUEST_STOP_STATE = ("SUCCEEDED", "FAILED", "CANCELED")
34+
DEFAULT_WAIT_TIME = 1200
35+
DEFAULT_POLL_INTERVAL = 10
36+
DEFAULT_WORKFLOW_STEPS = 2
3237

3338

3439
class MergeStrategy(Enum):
@@ -931,6 +936,76 @@ def get_work_request_response(
931936
)
932937
return work_request_response
933938

939+
def wait_for_progress(
940+
self,
941+
work_request_id: str,
942+
num_steps: int = DEFAULT_WORKFLOW_STEPS,
943+
max_wait_time: int = DEFAULT_WAIT_TIME,
944+
poll_interval: int = DEFAULT_POLL_INTERVAL
945+
):
946+
"""Waits for the work request progress bar to be completed.
947+
948+
Parameters
949+
----------
950+
work_request_id: str
951+
Work Request OCID.
952+
num_steps: (int, optional). Defaults to 2.
953+
Number of steps for the progress indicator.
954+
max_wait_time: int
955+
Maximum amount of time to wait in seconds (Defaults to 1200).
956+
Negative implies infinite wait time.
957+
poll_interval: int
958+
Poll interval in seconds (Defaults to 10).
959+
960+
Returns
961+
-------
962+
None
963+
"""
964+
work_request_logs = []
965+
966+
i = 0
967+
start_time = time.time()
968+
with get_progress_bar(num_steps) as progress:
969+
seconds_since = time.time() - start_time
970+
exceed_max_time = max_wait_time > 0 and seconds_since >= max_wait_time
971+
if exceed_max_time:
972+
logger.error(
973+
f"Max wait time ({max_wait_time} seconds) exceeded."
974+
)
975+
while not exceed_max_time and (not work_request_logs or len(work_request_logs) < num_steps):
976+
time.sleep(poll_interval)
977+
new_work_request_logs = []
978+
979+
try:
980+
work_request = self.client.get_work_request(work_request_id).data
981+
work_request_logs = self.client.list_work_request_logs(
982+
work_request_id
983+
).data
984+
except Exception as ex:
985+
logger.warn(ex)
986+
987+
new_work_request_logs = (
988+
work_request_logs[i:] if work_request_logs else []
989+
)
990+
991+
for wr_item in new_work_request_logs:
992+
progress.update(wr_item.message)
993+
i += 1
994+
995+
if work_request and work_request.status in WORK_REQUEST_STOP_STATE:
996+
if work_request.status != "SUCCEEDED":
997+
if new_work_request_logs:
998+
raise Exception(new_work_request_logs[-1].message)
999+
else:
1000+
raise Exception(
1001+
"Error occurred in attempt to perform the operation. "
1002+
"Check the service logs to get more details. "
1003+
f"{work_request}"
1004+
)
1005+
else:
1006+
break
1007+
progress.update("Done")
1008+
9341009

9351010
class OCIModelWithNameMixin:
9361011
"""Mixin class to operate OCI model which contains name property."""

ads/model/service/oci_datascience_model_deployment.py

Lines changed: 22 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,11 @@
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
from functools import wraps
8-
import json
9-
import time
108
import logging
119
from typing import Callable, List
1210
from ads.common.oci_datascience import OCIDataScienceMixin
13-
from ads.common import utils as progress_bar_utils
11+
from ads.common.oci_mixin import OCIWorkRequestMixin
1412
from ads.config import PROJECT_OCID
15-
from ads.model.deployment.common import utils
1613
from ads.model.deployment.common.utils import OCIClientManager, State
1714
import oci
1815

@@ -84,6 +81,7 @@ class MissingModelDeploymentWorkflowIdError(Exception): # pragma: no cover
8481

8582
class OCIDataScienceModelDeployment(
8683
OCIDataScienceMixin,
84+
OCIWorkRequestMixin,
8785
oci.data_science.models.ModelDeployment,
8886
):
8987
"""Represents an OCI Data Science Model Deployment.
@@ -188,20 +186,13 @@ def activate(
188186
if wait_for_completion:
189187

190188
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
191-
oci_model_deployment_object = self.client.get_model_deployment(self.id).data
192-
current_state = State._from_str(oci_model_deployment_object.lifecycle_state)
193-
model_deployment_id = self.id
194189

195190
try:
196-
self._wait_for_progress_completion(
197-
State.ACTIVE.name,
198-
ACTIVATE_WORKFLOW_STEPS,
199-
[State.FAILED.name, State.INACTIVE.name],
200-
self.workflow_req_id,
201-
current_state,
202-
model_deployment_id,
203-
max_wait_time,
204-
poll_interval,
191+
self.wait_for_progress(
192+
self.workflow_req_id,
193+
ACTIVATE_WORKFLOW_STEPS,
194+
max_wait_time,
195+
poll_interval
205196
)
206197
except Exception as e:
207198
logger.error(
@@ -243,20 +234,13 @@ def create(
243234
if wait_for_completion:
244235

245236
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
246-
res_payload = json.loads(str(response.data))
247-
current_state = State._from_str(res_payload["lifecycle_state"])
248-
model_deployment_id = self.id
249237

250238
try:
251-
self._wait_for_progress_completion(
252-
State.ACTIVE.name,
253-
CREATE_WORKFLOW_STEPS,
254-
[State.FAILED.name, State.INACTIVE.name],
255-
self.workflow_req_id,
256-
current_state,
257-
model_deployment_id,
258-
max_wait_time,
259-
poll_interval,
239+
self.wait_for_progress(
240+
self.workflow_req_id,
241+
CREATE_WORKFLOW_STEPS,
242+
max_wait_time,
243+
poll_interval
260244
)
261245
except Exception as e:
262246
logger.error(
@@ -301,20 +285,13 @@ def deactivate(
301285
if wait_for_completion:
302286

303287
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
304-
oci_model_deployment_object = self.client.get_model_deployment(self.id).data
305-
current_state = State._from_str(oci_model_deployment_object.lifecycle_state)
306-
model_deployment_id = self.id
307288

308289
try:
309-
self._wait_for_progress_completion(
310-
State.INACTIVE.name,
311-
DEACTIVATE_WORKFLOW_STEPS,
312-
[State.FAILED.name],
313-
self.workflow_req_id,
314-
current_state,
315-
model_deployment_id,
316-
max_wait_time,
317-
poll_interval,
290+
self.wait_for_progress(
291+
self.workflow_req_id,
292+
DEACTIVATE_WORKFLOW_STEPS,
293+
max_wait_time,
294+
poll_interval
318295
)
319296
except Exception as e:
320297
logger.error(
@@ -359,20 +336,13 @@ def delete(
359336
if wait_for_completion:
360337

361338
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
362-
oci_model_deployment_object = self.client.get_model_deployment(self.id).data
363-
current_state = State._from_str(oci_model_deployment_object.lifecycle_state)
364-
model_deployment_id = self.id
365339

366340
try:
367-
self._wait_for_progress_completion(
368-
State.DELETED.name,
369-
DELETE_WORKFLOW_STEPS,
370-
[State.FAILED.name, State.INACTIVE.name],
371-
self.workflow_req_id,
372-
current_state,
373-
model_deployment_id,
374-
max_wait_time,
375-
poll_interval,
341+
self.wait_for_progress(
342+
self.workflow_req_id,
343+
DELETE_WORKFLOW_STEPS,
344+
max_wait_time,
345+
poll_interval
376346
)
377347
except Exception as e:
378348
logger.error(
@@ -511,90 +481,3 @@ def from_id(cls, model_deployment_id: str) -> "OCIDataScienceModelDeployment":
511481
An instance of `OCIDataScienceModelDeployment`.
512482
"""
513483
return super().from_ocid(model_deployment_id)
514-
515-
def _wait_for_progress_completion(
516-
self,
517-
final_state: str,
518-
work_flow_step: int,
519-
disallowed_final_states: List[str],
520-
work_flow_request_id: str,
521-
state: State,
522-
model_deployment_id: str,
523-
max_wait_time: int = DEFAULT_WAIT_TIME,
524-
poll_interval: int = DEFAULT_POLL_INTERVAL,
525-
):
526-
"""_wait_for_progress_completion blocks until progress is completed.
527-
528-
Parameters
529-
----------
530-
final_state: str
531-
Final state of model deployment aimed to be reached.
532-
work_flow_step: int
533-
Number of work flow step of the request.
534-
disallowed_final_states: list[str]
535-
List of disallowed final state to be reached.
536-
work_flow_request_id: str
537-
The id of work flow request.
538-
state: State
539-
The current state of model deployment.
540-
model_deployment_id: str
541-
The ocid of model deployment.
542-
max_wait_time: int
543-
Maximum amount of time to wait in seconds (Defaults to 1200).
544-
Negative implies infinite wait time.
545-
poll_interval: int
546-
Poll interval in seconds (Defaults to 10).
547-
"""
548-
549-
start_time = time.time()
550-
prev_message = ""
551-
prev_workflow_stage_len = 0
552-
current_state = state or State.UNKNOWN
553-
with progress_bar_utils.get_progress_bar(work_flow_step) as progress:
554-
if max_wait_time > 0 and utils.seconds_since(start_time) >= max_wait_time:
555-
utils.get_logger().error(
556-
f"Max wait time ({max_wait_time} seconds) exceeded."
557-
)
558-
while (
559-
max_wait_time < 0 or utils.seconds_since(start_time) < max_wait_time
560-
) and current_state.name.upper() != final_state:
561-
if current_state.name.upper() in disallowed_final_states:
562-
utils.get_logger().info(
563-
f"Operation failed due to deployment reaching state {current_state.name.upper()}. Use Deployment ID for further steps."
564-
)
565-
break
566-
567-
prev_state = current_state.name
568-
try:
569-
model_deployment_payload = json.loads(
570-
str(self.client.get_model_deployment(model_deployment_id).data)
571-
)
572-
current_state = (
573-
State._from_str(model_deployment_payload["lifecycle_state"])
574-
if "lifecycle_state" in model_deployment_payload
575-
else State.UNKNOWN
576-
)
577-
workflow_payload = self.client.list_work_request_logs(
578-
work_flow_request_id
579-
).data
580-
if isinstance(workflow_payload, list) and len(workflow_payload) > 0:
581-
if prev_message != workflow_payload[-1].message:
582-
for _ in range(
583-
len(workflow_payload) - prev_workflow_stage_len
584-
):
585-
progress.update(workflow_payload[-1].message)
586-
prev_workflow_stage_len = len(workflow_payload)
587-
prev_message = workflow_payload[-1].message
588-
prev_workflow_stage_len = len(workflow_payload)
589-
if prev_state != current_state.name:
590-
utils.get_logger().info(
591-
f"Status Update: {current_state.name} in {utils.seconds_since(start_time)} seconds"
592-
)
593-
except Exception as e:
594-
# utils.get_logger().warning(
595-
# "Unable to update deployment status. Details: %s", format(
596-
# e)
597-
# )
598-
pass
599-
time.sleep(poll_interval)
600-
progress.update("Done")

0 commit comments

Comments
 (0)