Skip to content

Commit c66bd7d

Browse files
committed
Improved model deployment progress bar
1 parent 230cde6 commit c66bd7d

File tree

1 file changed

+78
-119
lines changed

1 file changed

+78
-119
lines changed

ads/model/service/oci_datascience_model_deployment.py

Lines changed: 78 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
from functools import wraps
8-
import json
98
import time
109
import logging
1110
from typing import Callable, List
@@ -19,6 +18,7 @@
1918
from oci.data_science.models import (
2019
CreateModelDeploymentDetails,
2120
UpdateModelDeploymentDetails,
21+
WorkRequest
2222
)
2323

2424
DEFAULT_WAIT_TIME = 1200
@@ -188,20 +188,13 @@ def activate(
188188
if wait_for_completion:
189189

190190
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
191-
oci_model_deployment_object = self.client.get_model_deployment(self.id).data
192-
current_state = State._from_str(oci_model_deployment_object.lifecycle_state)
193-
model_deployment_id = self.id
194191

195192
try:
196-
self._wait_for_progress_completion(
197-
State.ACTIVE.name,
198-
ACTIVATE_WORKFLOW_STEPS,
199-
[State.FAILED.name, State.INACTIVE.name],
200-
self.workflow_req_id,
201-
current_state,
202-
model_deployment_id,
203-
max_wait_time,
204-
poll_interval,
193+
self._wait_for_work_request(
194+
self.workflow_req_id,
195+
ACTIVATE_WORKFLOW_STEPS,
196+
max_wait_time,
197+
poll_interval
205198
)
206199
except Exception as e:
207200
logger.error(
@@ -243,20 +236,13 @@ def create(
243236
if wait_for_completion:
244237

245238
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
246-
res_payload = json.loads(str(response.data))
247-
current_state = State._from_str(res_payload["lifecycle_state"])
248-
model_deployment_id = self.id
249239

250240
try:
251-
self._wait_for_progress_completion(
252-
State.ACTIVE.name,
253-
CREATE_WORKFLOW_STEPS,
254-
[State.FAILED.name, State.INACTIVE.name],
255-
self.workflow_req_id,
256-
current_state,
257-
model_deployment_id,
258-
max_wait_time,
259-
poll_interval,
241+
self._wait_for_work_request(
242+
self.workflow_req_id,
243+
CREATE_WORKFLOW_STEPS,
244+
max_wait_time,
245+
poll_interval
260246
)
261247
except Exception as e:
262248
logger.error(
@@ -301,20 +287,13 @@ def deactivate(
301287
if wait_for_completion:
302288

303289
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
304-
oci_model_deployment_object = self.client.get_model_deployment(self.id).data
305-
current_state = State._from_str(oci_model_deployment_object.lifecycle_state)
306-
model_deployment_id = self.id
307290

308291
try:
309-
self._wait_for_progress_completion(
310-
State.INACTIVE.name,
311-
DEACTIVATE_WORKFLOW_STEPS,
312-
[State.FAILED.name],
313-
self.workflow_req_id,
314-
current_state,
315-
model_deployment_id,
316-
max_wait_time,
317-
poll_interval,
292+
self._wait_for_work_request(
293+
self.workflow_req_id,
294+
DEACTIVATE_WORKFLOW_STEPS,
295+
max_wait_time,
296+
poll_interval
318297
)
319298
except Exception as e:
320299
logger.error(
@@ -359,20 +338,13 @@ def delete(
359338
if wait_for_completion:
360339

361340
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
362-
oci_model_deployment_object = self.client.get_model_deployment(self.id).data
363-
current_state = State._from_str(oci_model_deployment_object.lifecycle_state)
364-
model_deployment_id = self.id
365341

366342
try:
367-
self._wait_for_progress_completion(
368-
State.DELETED.name,
369-
DELETE_WORKFLOW_STEPS,
370-
[State.FAILED.name, State.INACTIVE.name],
371-
self.workflow_req_id,
372-
current_state,
373-
model_deployment_id,
374-
max_wait_time,
375-
poll_interval,
343+
self._wait_for_work_request(
344+
self.workflow_req_id,
345+
DELETE_WORKFLOW_STEPS,
346+
max_wait_time,
347+
poll_interval
376348
)
377349
except Exception as e:
378350
logger.error(
@@ -512,89 +484,76 @@ def from_id(cls, model_deployment_id: str) -> "OCIDataScienceModelDeployment":
512484
"""
513485
return super().from_ocid(model_deployment_id)
514486

515-
def _wait_for_progress_completion(
516-
self,
517-
final_state: str,
518-
work_flow_step: int,
519-
disallowed_final_states: List[str],
520-
work_flow_request_id: str,
521-
state: State,
522-
model_deployment_id: str,
523-
max_wait_time: int = DEFAULT_WAIT_TIME,
524-
poll_interval: int = DEFAULT_POLL_INTERVAL,
525-
):
526-
"""_wait_for_progress_completion blocks until progress is completed.
487+
def _wait_for_work_request(
488+
self,
489+
work_request_id: str,
490+
num_steps: int = DELETE_WORKFLOW_STEPS,
491+
max_wait_time: int = DEFAULT_WAIT_TIME,
492+
poll_interval: int = DEFAULT_POLL_INTERVAL
493+
) -> None:
494+
"""Waits for the work request to be completed.
527495
528496
Parameters
529497
----------
530-
final_state: str
531-
Final state of model deployment aimed to be reached.
532-
work_flow_step: int
533-
Number of work flow step of the request.
534-
disallowed_final_states: list[str]
535-
List of disallowed final state to be reached.
536-
work_flow_request_id: str
537-
The id of work flow request.
538-
state: State
539-
The current state of model deployment.
540-
model_deployment_id: str
541-
The ocid of model deployment.
498+
work_request_id: str
499+
Work Request OCID.
500+
num_steps: (int, optional). Defaults to 6.
501+
Number of steps for the progress indicator.
542502
max_wait_time: int
543503
Maximum amount of time to wait in seconds (Defaults to 1200).
544504
Negative implies infinite wait time.
545505
poll_interval: int
546506
Poll interval in seconds (Defaults to 10).
507+
508+
Returns
509+
-------
510+
None
547511
"""
512+
STOP_STATE = (
513+
WorkRequest.STATUS_SUCCEEDED,
514+
WorkRequest.STATUS_CANCELED,
515+
WorkRequest.STATUS_FAILED,
516+
)
517+
work_request_logs = []
548518

519+
i = 0
549520
start_time = time.time()
550-
prev_message = ""
551-
prev_workflow_stage_len = 0
552-
current_state = state or State.UNKNOWN
553-
with progress_bar_utils.get_progress_bar(work_flow_step) as progress:
554-
if max_wait_time > 0 and utils.seconds_since(start_time) >= max_wait_time:
555-
utils.get_logger().error(
521+
with progress_bar_utils.get_progress_bar(num_steps) as progress:
522+
exceed_max_time = max_wait_time > 0 and utils.seconds_since(start_time) >= max_wait_time
523+
if exceed_max_time:
524+
logger.error(
556525
f"Max wait time ({max_wait_time} seconds) exceeded."
557526
)
558-
while (
559-
max_wait_time < 0 or utils.seconds_since(start_time) < max_wait_time
560-
) and current_state.name.upper() != final_state:
561-
if current_state.name.upper() in disallowed_final_states:
562-
utils.get_logger().info(
563-
f"Operation failed due to deployment reaching state {current_state.name.upper()}. Use Deployment ID for further steps."
564-
)
565-
break
566-
567-
prev_state = current_state.name
527+
while not exceed_max_time and (not work_request_logs or len(work_request_logs) < num_steps):
528+
time.sleep(poll_interval)
529+
new_work_request_logs = []
530+
568531
try:
569-
model_deployment_payload = json.loads(
570-
str(self.client.get_model_deployment(model_deployment_id).data)
571-
)
572-
current_state = (
573-
State._from_str(model_deployment_payload["lifecycle_state"])
574-
if "lifecycle_state" in model_deployment_payload
575-
else State.UNKNOWN
576-
)
577-
workflow_payload = self.client.list_work_request_logs(
578-
work_flow_request_id
532+
work_request = self.client.get_work_request(work_request_id).data
533+
work_request_logs = self.client.list_work_request_logs(
534+
work_request_id
579535
).data
580-
if isinstance(workflow_payload, list) and len(workflow_payload) > 0:
581-
if prev_message != workflow_payload[-1].message:
582-
for _ in range(
583-
len(workflow_payload) - prev_workflow_stage_len
584-
):
585-
progress.update(workflow_payload[-1].message)
586-
prev_workflow_stage_len = len(workflow_payload)
587-
prev_message = workflow_payload[-1].message
588-
prev_workflow_stage_len = len(workflow_payload)
589-
if prev_state != current_state.name:
590-
utils.get_logger().info(
591-
f"Status Update: {current_state.name} in {utils.seconds_since(start_time)} seconds"
592-
)
593-
except Exception as e:
594-
# utils.get_logger().warning(
595-
# "Unable to update deployment status. Details: %s", format(
596-
# e)
597-
# )
598-
pass
599-
time.sleep(poll_interval)
536+
except Exception as ex:
537+
logger.warn(ex)
538+
539+
new_work_request_logs = (
540+
work_request_logs[i:] if work_request_logs else []
541+
)
542+
543+
for wr_item in new_work_request_logs:
544+
progress.update(wr_item.message)
545+
i += 1
546+
547+
if work_request and work_request.status in STOP_STATE:
548+
if work_request.status != WorkRequest.STATUS_SUCCEEDED:
549+
if new_work_request_logs:
550+
raise Exception(new_work_request_logs[-1].message)
551+
else:
552+
raise Exception(
553+
"Error occurred in attempt to perform the operation. "
554+
"Check the service logs to get more details. "
555+
f"{work_request}"
556+
)
557+
else:
558+
break
600559
progress.update("Done")

0 commit comments

Comments
 (0)