Skip to content

Commit 358c6fa

Browse files
committed
Updated pr.
1 parent cde9aee commit 358c6fa

File tree

2 files changed

+152
-90
lines changed

2 files changed

+152
-90
lines changed

ads/model/service/oci_datascience_model_deployment.py

Lines changed: 122 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
from functools import wraps
8+
import json
89
import time
910
import logging
1011
from typing import Callable, List
@@ -18,7 +19,6 @@
1819
from oci.data_science.models import (
1920
CreateModelDeploymentDetails,
2021
UpdateModelDeploymentDetails,
21-
WorkRequest
2222
)
2323

2424
DEFAULT_WAIT_TIME = 1200
@@ -74,11 +74,11 @@ def wrapper(self, *args, **kwargs):
7474
return decorator
7575

7676

77-
class MissingModelDeploymentIdError(Exception):
77+
class MissingModelDeploymentIdError(Exception): # pragma: no cover
7878
pass
7979

8080

81-
class MissingModelDeploymentWorkflowIdError(Exception):
81+
class MissingModelDeploymentWorkflowIdError(Exception): # pragma: no cover
8282
pass
8383

8484

@@ -188,13 +188,20 @@ def activate(
188188
if wait_for_completion:
189189

190190
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
191+
oci_model_deployment_object = self.client.get_model_deployment(self.id).data
192+
current_state = State._from_str(oci_model_deployment_object.lifecycle_state)
193+
model_deployment_id = self.id
191194

192195
try:
193-
self._wait_for_work_request(
194-
self.workflow_req_id,
195-
ACTIVATE_WORKFLOW_STEPS,
196-
max_wait_time,
197-
poll_interval
196+
self._wait_for_progress_completion(
197+
State.ACTIVE.name,
198+
ACTIVATE_WORKFLOW_STEPS,
199+
[State.FAILED.name, State.INACTIVE.name],
200+
self.workflow_req_id,
201+
current_state,
202+
model_deployment_id,
203+
max_wait_time,
204+
poll_interval,
198205
)
199206
except Exception as e:
200207
logger.error(
@@ -236,13 +243,20 @@ def create(
236243
if wait_for_completion:
237244

238245
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
246+
res_payload = json.loads(str(response.data))
247+
current_state = State._from_str(res_payload["lifecycle_state"])
248+
model_deployment_id = self.id
239249

240250
try:
241-
self._wait_for_work_request(
242-
self.workflow_req_id,
243-
CREATE_WORKFLOW_STEPS,
244-
max_wait_time,
245-
poll_interval
251+
self._wait_for_progress_completion(
252+
State.ACTIVE.name,
253+
CREATE_WORKFLOW_STEPS,
254+
[State.FAILED.name, State.INACTIVE.name],
255+
self.workflow_req_id,
256+
current_state,
257+
model_deployment_id,
258+
max_wait_time,
259+
poll_interval,
246260
)
247261
except Exception as e:
248262
logger.error(
@@ -287,13 +301,20 @@ def deactivate(
287301
if wait_for_completion:
288302

289303
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
304+
oci_model_deployment_object = self.client.get_model_deployment(self.id).data
305+
current_state = State._from_str(oci_model_deployment_object.lifecycle_state)
306+
model_deployment_id = self.id
290307

291308
try:
292-
self._wait_for_work_request(
293-
self.workflow_req_id,
294-
DEACTIVATE_WORKFLOW_STEPS,
295-
max_wait_time,
296-
poll_interval
309+
self._wait_for_progress_completion(
310+
State.INACTIVE.name,
311+
DEACTIVATE_WORKFLOW_STEPS,
312+
[State.FAILED.name],
313+
self.workflow_req_id,
314+
current_state,
315+
model_deployment_id,
316+
max_wait_time,
317+
poll_interval,
297318
)
298319
except Exception as e:
299320
logger.error(
@@ -338,13 +359,20 @@ def delete(
338359
if wait_for_completion:
339360

340361
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
362+
oci_model_deployment_object = self.client.get_model_deployment(self.id).data
363+
current_state = State._from_str(oci_model_deployment_object.lifecycle_state)
364+
model_deployment_id = self.id
341365

342366
try:
343-
self._wait_for_work_request(
344-
self.workflow_req_id,
345-
DELETE_WORKFLOW_STEPS,
346-
max_wait_time,
347-
poll_interval
367+
self._wait_for_progress_completion(
368+
State.DELETED.name,
369+
DELETE_WORKFLOW_STEPS,
370+
[State.FAILED.name, State.INACTIVE.name],
371+
self.workflow_req_id,
372+
current_state,
373+
model_deployment_id,
374+
max_wait_time,
375+
poll_interval,
348376
)
349377
except Exception as e:
350378
logger.error(
@@ -484,76 +512,89 @@ def from_id(cls, model_deployment_id: str) -> "OCIDataScienceModelDeployment":
484512
"""
485513
return super().from_ocid(model_deployment_id)
486514

487-
def _wait_for_work_request(
488-
self,
489-
work_request_id: str,
490-
num_steps: int = DELETE_WORKFLOW_STEPS,
491-
max_wait_time: int = DEFAULT_WAIT_TIME,
492-
poll_interval: int = DEFAULT_POLL_INTERVAL
493-
) -> None:
494-
"""Waits for the work request to be completed.
515+
def _wait_for_progress_completion(
516+
self,
517+
final_state: str,
518+
work_flow_step: int,
519+
disallowed_final_states: List[str],
520+
work_flow_request_id: str,
521+
state: State,
522+
model_deployment_id: str,
523+
max_wait_time: int = DEFAULT_WAIT_TIME,
524+
poll_interval: int = DEFAULT_POLL_INTERVAL,
525+
):
526+
"""_wait_for_progress_completion blocks until progress is completed.
495527
496528
Parameters
497529
----------
498-
work_request_id: str
499-
Work Request OCID.
500-
num_steps: (int, optional). Defaults to 6.
501-
Number of steps for the progress indicator.
530+
final_state: str
531+
Final state of model deployment aimed to be reached.
532+
work_flow_step: int
533+
Number of work flow step of the request.
534+
disallowed_final_states: list[str]
535+
List of disallowed final state to be reached.
536+
work_flow_request_id: str
537+
The id of work flow request.
538+
state: State
539+
The current state of model deployment.
540+
model_deployment_id: str
541+
The ocid of model deployment.
502542
max_wait_time: int
503543
Maximum amount of time to wait in seconds (Defaults to 1200).
504544
Negative implies infinite wait time.
505545
poll_interval: int
506546
Poll interval in seconds (Defaults to 10).
507-
508-
Returns
509-
-------
510-
None
511547
"""
512-
STOP_STATE = (
513-
WorkRequest.STATUS_SUCCEEDED,
514-
WorkRequest.STATUS_CANCELED,
515-
WorkRequest.STATUS_FAILED,
516-
)
517-
work_request_logs = []
518548

519-
i = 0
520549
start_time = time.time()
521-
with progress_bar_utils.get_progress_bar(num_steps) as progress:
522-
exceed_max_time = max_wait_time > 0 and utils.seconds_since(start_time) >= max_wait_time
523-
if exceed_max_time:
524-
logger.error(
550+
prev_message = ""
551+
prev_workflow_stage_len = 0
552+
current_state = state or State.UNKNOWN
553+
with progress_bar_utils.get_progress_bar(work_flow_step) as progress:
554+
if max_wait_time > 0 and utils.seconds_since(start_time) >= max_wait_time:
555+
utils.get_logger().error(
525556
f"Max wait time ({max_wait_time} seconds) exceeded."
526557
)
527-
while not exceed_max_time and (not work_request_logs or len(work_request_logs) < num_steps):
528-
time.sleep(poll_interval)
529-
new_work_request_logs = []
530-
558+
while (
559+
max_wait_time < 0 or utils.seconds_since(start_time) < max_wait_time
560+
) and current_state.name.upper() != final_state:
561+
if current_state.name.upper() in disallowed_final_states:
562+
utils.get_logger().info(
563+
f"Operation failed due to deployment reaching state {current_state.name.upper()}. Use Deployment ID for further steps."
564+
)
565+
break
566+
567+
prev_state = current_state.name
531568
try:
532-
work_request = self.client.get_work_request(work_request_id).data
533-
work_request_logs = self.client.list_work_request_logs(
534-
work_request_id
569+
model_deployment_payload = json.loads(
570+
str(self.client.get_model_deployment(model_deployment_id).data)
571+
)
572+
current_state = (
573+
State._from_str(model_deployment_payload["lifecycle_state"])
574+
if "lifecycle_state" in model_deployment_payload
575+
else State.UNKNOWN
576+
)
577+
workflow_payload = self.client.list_work_request_logs(
578+
work_flow_request_id
535579
).data
536-
except Exception as ex:
537-
logger.warn(ex)
538-
539-
new_work_request_logs = (
540-
work_request_logs[i:] if work_request_logs else []
541-
)
542-
543-
for wr_item in new_work_request_logs:
544-
progress.update(wr_item.message)
545-
i += 1
546-
547-
if work_request and work_request.status in STOP_STATE:
548-
if work_request.status != WorkRequest.STATUS_SUCCEEDED:
549-
if new_work_request_logs:
550-
raise Exception(new_work_request_logs[-1].message)
551-
else:
552-
raise Exception(
553-
"Error occurred in attempt to perform the operation. "
554-
"Check the service logs to get more details. "
555-
f"{work_request}"
556-
)
557-
else:
558-
break
559-
progress.update("Done")
580+
if isinstance(workflow_payload, list) and len(workflow_payload) > 0:
581+
if prev_message != workflow_payload[-1].message:
582+
for _ in range(
583+
len(workflow_payload) - prev_workflow_stage_len
584+
):
585+
progress.update(workflow_payload[-1].message)
586+
prev_workflow_stage_len = len(workflow_payload)
587+
prev_message = workflow_payload[-1].message
588+
prev_workflow_stage_len = len(workflow_payload)
589+
if prev_state != current_state.name:
590+
utils.get_logger().info(
591+
f"Status Update: {current_state.name} in {utils.seconds_since(start_time)} seconds"
592+
)
593+
except Exception as e:
594+
# utils.get_logger().warning(
595+
# "Unable to update deployment status. Details: %s", format(
596+
# e)
597+
# )
598+
pass
599+
time.sleep(poll_interval)
600+
progress.update("Done")

0 commit comments

Comments
 (0)