Skip to content

Commit 14307d2

Browse files
committed
Updated pr.
1 parent 65207bd commit 14307d2

File tree

3 files changed

+87
-88
lines changed

3 files changed

+87
-88
lines changed

ads/common/oci_mixin.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import logging
1212
import os
1313
import re
14+
import time
1415
import traceback
1516
from datetime import date, datetime
1617
from typing import Callable, Optional, Union
@@ -20,7 +21,7 @@
2021
import yaml
2122
from ads.common import auth
2223
from ads.common.decorator.utils import class_or_instance_method
23-
from ads.common.utils import camel_to_snake
24+
from ads.common.utils import camel_to_snake, get_progress_bar
2425
from ads.config import COMPARTMENT_OCID
2526
from dateutil import tz
2627
from dateutil.parser import parse
@@ -29,6 +30,10 @@
2930
logger = logging.getLogger(__name__)
3031

3132
LIFECYCLE_STOP_STATE = ("SUCCEEDED", "FAILED", "CANCELED", "DELETED")
33+
WORK_REQUEST_STOP_STATE = ("SUCCEEDED", "FAILED", "CANCELED")
34+
DEFAULT_WAIT_TIME = 1200
35+
DEFAULT_POLL_INTERVAL = 10
36+
DEFAULT_WORKFLOW_STEPS = 2
3237

3338

3439
class MergeStrategy(Enum):
@@ -931,6 +936,76 @@ def get_work_request_response(
931936
)
932937
return work_request_response
933938

939+
def wait_for_progress(
940+
self,
941+
work_request_id: str,
942+
num_steps: int = DEFAULT_WORKFLOW_STEPS,
943+
max_wait_time: int = DEFAULT_WAIT_TIME,
944+
poll_interval: int = DEFAULT_POLL_INTERVAL
945+
):
946+
"""Waits for the work request progress bar to be completed.
947+
948+
Parameters
949+
----------
950+
work_request_id: str
951+
Work Request OCID.
952+
num_steps: (int, optional). Defaults to 2.
953+
Number of steps for the progress indicator.
954+
max_wait_time: int
955+
Maximum amount of time to wait in seconds (Defaults to 1200).
956+
Negative implies infinite wait time.
957+
poll_interval: int
958+
Poll interval in seconds (Defaults to 10).
959+
960+
Returns
961+
-------
962+
None
963+
"""
964+
work_request_logs = []
965+
966+
i = 0
967+
start_time = time.time()
968+
with get_progress_bar(num_steps) as progress:
969+
seconds_since = time.time() - start_time
970+
exceed_max_time = max_wait_time > 0 and seconds_since >= max_wait_time
971+
if exceed_max_time:
972+
logger.error(
973+
f"Max wait time ({max_wait_time} seconds) exceeded."
974+
)
975+
while not exceed_max_time and (not work_request_logs or len(work_request_logs) < num_steps):
976+
time.sleep(poll_interval)
977+
new_work_request_logs = []
978+
979+
try:
980+
work_request = self.client.get_work_request(work_request_id).data
981+
work_request_logs = self.client.list_work_request_logs(
982+
work_request_id
983+
).data
984+
except Exception as ex:
985+
logger.warn(ex)
986+
987+
new_work_request_logs = (
988+
work_request_logs[i:] if work_request_logs else []
989+
)
990+
991+
for wr_item in new_work_request_logs:
992+
progress.update(wr_item.message)
993+
i += 1
994+
995+
if work_request and work_request.status in WORK_REQUEST_STOP_STATE:
996+
if work_request.status != "SUCCEEDED":
997+
if new_work_request_logs:
998+
raise Exception(new_work_request_logs[-1].message)
999+
else:
1000+
raise Exception(
1001+
"Error occurred in attempt to perform the operation. "
1002+
"Check the service logs to get more details. "
1003+
f"{work_request}"
1004+
)
1005+
else:
1006+
break
1007+
progress.update("Done")
1008+
9341009

9351010
class OCIModelWithNameMixin:
9361011
"""Mixin class to operate OCI model which contains name property."""

ads/model/service/oci_datascience_model_deployment.py

Lines changed: 6 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,17 @@
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
from functools import wraps
8-
import time
98
import logging
109
from typing import Callable, List
1110
from ads.common.oci_datascience import OCIDataScienceMixin
12-
from ads.common import utils as progress_bar_utils
11+
from ads.common.oci_mixin import OCIWorkRequestMixin
1312
from ads.config import PROJECT_OCID
14-
from ads.model.deployment.common import utils
1513
from ads.model.deployment.common.utils import OCIClientManager, State
1614
import oci
1715

1816
from oci.data_science.models import (
1917
CreateModelDeploymentDetails,
2018
UpdateModelDeploymentDetails,
21-
WorkRequest
2219
)
2320

2421
DEFAULT_WAIT_TIME = 1200
@@ -84,6 +81,7 @@ class MissingModelDeploymentWorkflowIdError(Exception):
8481

8582
class OCIDataScienceModelDeployment(
8683
OCIDataScienceMixin,
84+
OCIWorkRequestMixin,
8785
oci.data_science.models.ModelDeployment,
8886
):
8987
"""Represents an OCI Data Science Model Deployment.
@@ -190,7 +188,7 @@ def activate(
190188
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
191189

192190
try:
193-
self._wait_for_work_request(
191+
self.wait_for_progress(
194192
self.workflow_req_id,
195193
ACTIVATE_WORKFLOW_STEPS,
196194
max_wait_time,
@@ -238,7 +236,7 @@ def create(
238236
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
239237

240238
try:
241-
self._wait_for_work_request(
239+
self.wait_for_progress(
242240
self.workflow_req_id,
243241
CREATE_WORKFLOW_STEPS,
244242
max_wait_time,
@@ -289,7 +287,7 @@ def deactivate(
289287
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
290288

291289
try:
292-
self._wait_for_work_request(
290+
self.wait_for_progress(
293291
self.workflow_req_id,
294292
DEACTIVATE_WORKFLOW_STEPS,
295293
max_wait_time,
@@ -340,7 +338,7 @@ def delete(
340338
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
341339

342340
try:
343-
self._wait_for_work_request(
341+
self.wait_for_progress(
344342
self.workflow_req_id,
345343
DELETE_WORKFLOW_STEPS,
346344
max_wait_time,
@@ -483,77 +481,3 @@ def from_id(cls, model_deployment_id: str) -> "OCIDataScienceModelDeployment":
483481
An instance of `OCIDataScienceModelDeployment`.
484482
"""
485483
return super().from_ocid(model_deployment_id)
486-
487-
def _wait_for_work_request(
488-
self,
489-
work_request_id: str,
490-
num_steps: int = DELETE_WORKFLOW_STEPS,
491-
max_wait_time: int = DEFAULT_WAIT_TIME,
492-
poll_interval: int = DEFAULT_POLL_INTERVAL
493-
) -> None:
494-
"""Waits for the work request to be completed.
495-
496-
Parameters
497-
----------
498-
work_request_id: str
499-
Work Request OCID.
500-
num_steps: (int, optional). Defaults to 6.
501-
Number of steps for the progress indicator.
502-
max_wait_time: int
503-
Maximum amount of time to wait in seconds (Defaults to 1200).
504-
Negative implies infinite wait time.
505-
poll_interval: int
506-
Poll interval in seconds (Defaults to 10).
507-
508-
Returns
509-
-------
510-
None
511-
"""
512-
STOP_STATE = (
513-
WorkRequest.STATUS_SUCCEEDED,
514-
WorkRequest.STATUS_CANCELED,
515-
WorkRequest.STATUS_FAILED,
516-
)
517-
work_request_logs = []
518-
519-
i = 0
520-
start_time = time.time()
521-
with progress_bar_utils.get_progress_bar(num_steps) as progress:
522-
exceed_max_time = max_wait_time > 0 and utils.seconds_since(start_time) >= max_wait_time
523-
if exceed_max_time:
524-
logger.error(
525-
f"Max wait time ({max_wait_time} seconds) exceeded."
526-
)
527-
while not exceed_max_time and (not work_request_logs or len(work_request_logs) < num_steps):
528-
time.sleep(poll_interval)
529-
new_work_request_logs = []
530-
531-
try:
532-
work_request = self.client.get_work_request(work_request_id).data
533-
work_request_logs = self.client.list_work_request_logs(
534-
work_request_id
535-
).data
536-
except Exception as ex:
537-
logger.warn(ex)
538-
539-
new_work_request_logs = (
540-
work_request_logs[i:] if work_request_logs else []
541-
)
542-
543-
for wr_item in new_work_request_logs:
544-
progress.update(wr_item.message)
545-
i += 1
546-
547-
if work_request and work_request.status in STOP_STATE:
548-
if work_request.status != WorkRequest.STATUS_SUCCEEDED:
549-
if new_work_request_logs:
550-
raise Exception(new_work_request_logs[-1].message)
551-
else:
552-
raise Exception(
553-
"Error occurred in attempt to perform the operation. "
554-
"Check the service logs to get more details. "
555-
f"{work_request}"
556-
)
557-
else:
558-
break
559-
progress.update("Done")

tests/unitary/default_setup/model_deployment/test_oci_datascience_model_deployment.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
ModelDeployment,
1111
)
1212
from ads.common.oci_datascience import OCIDataScienceMixin
13-
from ads.common.oci_mixin import OCIModelMixin
13+
from ads.common.oci_mixin import OCIModelMixin, OCIWorkRequestMixin
1414

1515
from ads.model.service.oci_datascience_model_deployment import (
1616
ACTIVATE_WORKFLOW_STEPS,
@@ -116,7 +116,7 @@ def test_activate_with_waiting(self):
116116
response.data = data
117117
mock_get.return_value = response
118118
with patch.object(
119-
OCIDataScienceModelDeployment, "_wait_for_work_request"
119+
OCIWorkRequestMixin, "wait_for_progress"
120120
) as mock_wait:
121121
with patch.object(
122122
OCIDataScienceModelDeployment, "sync"
@@ -169,7 +169,7 @@ def test_deactivate_with_waiting(self):
169169
response.data = data
170170
mock_get.return_value = response
171171
with patch.object(
172-
OCIDataScienceModelDeployment, "_wait_for_work_request"
172+
OCIWorkRequestMixin, "wait_for_progress"
173173
) as mock_wait:
174174
with patch.object(
175175
OCIDataScienceModelDeployment, "sync"
@@ -243,7 +243,7 @@ def test_create_with_waiting(self):
243243
)
244244
mock_to_oci_mode.return_value = oci_model_deployment
245245
with patch.object(
246-
OCIDataScienceModelDeployment, "_wait_for_work_request"
246+
OCIWorkRequestMixin, "wait_for_progress"
247247
) as mock_wait:
248248
with patch("json.loads") as mock_json_load:
249249
create_model_deployment_details = MagicMock()
@@ -354,7 +354,7 @@ def test_delete_with_waiting(self):
354354
}
355355
mock_delete.return_value = response
356356
with patch.object(
357-
OCIDataScienceModelDeployment, "_wait_for_work_request"
357+
OCIWorkRequestMixin, "wait_for_progress"
358358
) as mock_wait:
359359
with patch.object(
360360
oci.data_science.DataScienceClient, "get_model_deployment"

0 commit comments

Comments
 (0)