Skip to content

Commit 7e8bd6b

Browse files
committed
Updated pr and fixed unit tests.
1 parent 8e24e15 commit 7e8bd6b

File tree

4 files changed

+23
-213
lines changed

4 files changed

+23
-213
lines changed

ads/common/oci_mixin.py

Lines changed: 1 addition & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,16 @@
1111
import logging
1212
import os
1313
import re
14-
import sys
15-
import time
1614
import traceback
1715
from datetime import date, datetime
1816
from typing import Callable, Optional, Union
1917
from enum import Enum
2018

2119
import oci
22-
from tqdm import tqdm
2320
import yaml
2421
from ads.common import auth
2522
from ads.common.decorator.utils import class_or_instance_method
26-
from ads.common.utils import camel_to_snake, get_progress_bar
23+
from ads.common.utils import camel_to_snake
2724
from ads.config import COMPARTMENT_OCID
2825
from dateutil import tz
2926
from dateutil.parser import parse
@@ -938,88 +935,6 @@ def get_work_request_response(
938935
)
939936
return work_request_response
940937

941-
def wait_for_progress(
942-
self,
943-
work_request_id: str,
944-
max_wait_time: int = DEFAULT_WAIT_TIME,
945-
poll_interval: int = DEFAULT_POLL_INTERVAL,
946-
):
947-
"""Waits for the work request progress bar to be completed.
948-
949-
Parameters
950-
----------
951-
work_request_id: str
952-
Work Request OCID.
953-
max_wait_time: int
954-
Maximum amount of time to wait in seconds (Defaults to 1200).
955-
Negative implies infinite wait time.
956-
poll_interval: int
957-
Poll interval in seconds (Defaults to 10).
958-
959-
Returns
960-
-------
961-
None
962-
"""
963-
work_request_logs = []
964-
965-
i = 0
966-
start_time = time.time()
967-
with get_progress_bar(WORK_REQUEST_PERCENTAGE) as progress:
968-
seconds_since = time.time() - start_time
969-
exceed_max_time = max_wait_time > 0 and seconds_since >= max_wait_time
970-
if exceed_max_time:
971-
logger.error(f"Max wait time ({max_wait_time} seconds) exceeded.")
972-
previous_percent_complete = 0
973-
while not exceed_max_time and (
974-
not work_request_logs or previous_percent_complete <= WORK_REQUEST_PERCENTAGE
975-
):
976-
time.sleep(poll_interval)
977-
new_work_request_logs = []
978-
979-
try:
980-
work_request = self.client.get_work_request(work_request_id).data
981-
work_request_logs = self.client.list_work_request_logs(
982-
work_request_id
983-
).data
984-
except Exception as ex:
985-
logger.warn(ex)
986-
987-
new_work_request_logs = (
988-
work_request_logs[i:] if work_request_logs else []
989-
)
990-
991-
percent_change = work_request.percent_complete - previous_percent_complete
992-
previous_percent_complete = work_request.percent_complete
993-
994-
if len(new_work_request_logs) > 0:
995-
start_index = True
996-
for wr_item in new_work_request_logs:
997-
if start_index:
998-
progress.update(wr_item.message, percent_change)
999-
start_index = False
1000-
else:
1001-
progress.update(wr_item.message, 0)
1002-
i += 1
1003-
else:
1004-
# if there is new percent change but the new work request logs is empty
1005-
# needs to add this percent change to the bar to ensure the final percentage is 100
1006-
if percent_change != 0:
1007-
progress.update(n=percent_change)
1008-
1009-
if work_request and work_request.status in WORK_REQUEST_STOP_STATE:
1010-
if work_request.status != "SUCCEEDED":
1011-
if new_work_request_logs:
1012-
raise Exception(new_work_request_logs[-1].message)
1013-
else:
1014-
raise Exception(
1015-
"Error occurred in attempt to perform the operation. "
1016-
"Check the service logs to get more details. "
1017-
f"{work_request}"
1018-
)
1019-
else:
1020-
break
1021-
progress.update("Done")
1022-
1023938

1024939
class OCIModelWithNameMixin:
1025940
"""Mixin class to operate OCI model which contains name property."""

ads/common/work_request.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8; -*-
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
16

27
import logging
38
import sys

tests/unitary/default_setup/model/test_oci_datascience_model.py

Lines changed: 4 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def test_model_deployment(self, mock_search):
375375
**{"kwargkey": "kwargvalue"},
376376
)
377377

378-
@patch.object(OCIDataScienceModel, "_wait_for_work_request")
378+
@patch("ads.model.service.oci_datascience_model.wait_work_request")
379379
def test_import_model_artifact_success(
380380
self,
381381
mock_wait_for_work_request,
@@ -400,10 +400,7 @@ def test_import_model_artifact_success(
400400
)
401401
),
402402
)
403-
mock_wait_for_work_request.assert_called_with(
404-
work_request_id="work_request_id",
405-
num_steps=2,
406-
)
403+
mock_wait_for_work_request.assert_called_with("work_request_id")
407404

408405
@patch.object(OCIDataScienceModel, "client")
409406
def test_import_model_artifact_fail(self, mock_client):
@@ -419,7 +416,7 @@ def test_import_model_artifact_fail(self, mock_client):
419416
bucket_uri=test_bucket_uri, region="test_region"
420417
)
421418

422-
@patch.object(OCIDataScienceModel, "_wait_for_work_request")
419+
@patch("ads.model.service.oci_datascience_model.wait_work_request")
423420
def test_export_model_artifact(
424421
self,
425422
mock_wait_for_work_request,
@@ -444,115 +441,4 @@ def test_export_model_artifact(
444441
)
445442
),
446443
)
447-
mock_wait_for_work_request.assert_called_with(
448-
work_request_id="work_request_id",
449-
num_steps=2,
450-
)
451-
452-
@patch.object(TqdmProgressBar, "update")
453-
def test__wait_for_work_request_fail(self, mock_tqdm_update, mock_client):
454-
mock_client.get_work_request = MagicMock(
455-
return_value=Response(
456-
data=WorkRequest(id="work_request_id", status="FAILED"),
457-
status=None,
458-
headers={"opc-work-request-id": "work_request_id"},
459-
request=None,
460-
)
461-
)
462-
mock_client.list_work_request_logs = MagicMock(
463-
return_value=Response(
464-
data=[
465-
WorkRequestLogEntry(message="test_message_1"),
466-
WorkRequestLogEntry(message="error_message_1"),
467-
],
468-
status=None,
469-
headers=None,
470-
request=None,
471-
)
472-
)
473-
with patch.object(
474-
OCIDataScienceModel,
475-
"client",
476-
new_callable=PropertyMock,
477-
return_value=mock_client,
478-
):
479-
with pytest.raises(Exception, match="error_message_1"):
480-
self.mock_model._wait_for_work_request(
481-
work_request_id="work_request_id", num_steps=2
482-
)
483-
mock_tqdm_update.assert_has_calls(
484-
[
485-
call("test_message_1"),
486-
call("error_message_1"),
487-
]
488-
)
489-
assert mock_tqdm_update.call_count == 2
490-
491-
@patch.object(TqdmProgressBar, "update")
492-
def test__wait_for_work_request_fail_generic(self, mock_tqdm_update, mock_client):
493-
mock_client.get_work_request = MagicMock(
494-
return_value=Response(
495-
data=WorkRequest(id="work_request_id", status="FAILED"),
496-
status=None,
497-
headers={"opc-work-request-id": "work_request_id"},
498-
request=None,
499-
)
500-
)
501-
mock_client.list_work_request_logs = MagicMock(
502-
return_value=Response(
503-
data=[],
504-
status=None,
505-
headers=None,
506-
request=None,
507-
)
508-
)
509-
with patch.object(
510-
OCIDataScienceModel,
511-
"client",
512-
new_callable=PropertyMock,
513-
return_value=mock_client,
514-
):
515-
with pytest.raises(
516-
Exception, match="^Error occurred in attempt to perform the operation*"
517-
):
518-
self.mock_model._wait_for_work_request(
519-
work_request_id="work_request_id", num_steps=2
520-
)
521-
mock_tqdm_update.assert_not_called()
522-
523-
@patch.object(TqdmProgressBar, "update")
524-
def test__wait_for_work_request_success(self, mock_tqdm_update, mock_client):
525-
mock_client.get_work_request = MagicMock(
526-
return_value=Response(
527-
data=WorkRequest(id="work_request_id", status="SUCCEEDED"),
528-
status=None,
529-
headers={"opc-work-request-id": "work_request_id"},
530-
request=None,
531-
)
532-
)
533-
mock_client.list_work_request_logs = MagicMock(
534-
return_value=Response(
535-
data=[
536-
WorkRequestLogEntry(message="test_message_1"),
537-
WorkRequestLogEntry(message="test_message_2"),
538-
],
539-
status=None,
540-
headers=None,
541-
request=None,
542-
)
543-
)
544-
with patch.object(
545-
OCIDataScienceModel,
546-
"client",
547-
new_callable=PropertyMock,
548-
return_value=mock_client,
549-
):
550-
self.mock_model._wait_for_work_request(
551-
work_request_id="work_request_id", num_steps=2
552-
)
553-
# mock_tqdm_update.assert_has_calls(
554-
# [call("test_message_1"), call("test_message_2")]
555-
# )
556-
# assert mock_tqdm_update.call_count == 2
557-
# mock_tqdm_update.assert_called()
558-
# assert mock_tqdm_update.call_count == 2
444+
mock_wait_for_work_request.assert_called_with("work_request_id")

tests/unitary/default_setup/model_deployment/test_oci_datascience_model_deployment.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
ModelDeployment,
1313
)
1414
from ads.common.oci_datascience import OCIDataScienceMixin
15-
from ads.common.oci_mixin import OCIModelMixin, OCIWorkRequestMixin
15+
from ads.common.oci_mixin import OCIModelMixin
1616

1717
from ads.model.service.oci_datascience_model_deployment import (
1818
OCIDataScienceModelDeployment,
@@ -145,8 +145,8 @@ def test_activate_with_waiting(self):
145145
"opc-work-request-id": "test",
146146
}
147147
mock_activate.return_value = response
148-
with patch.object(
149-
OCIWorkRequestMixin, "wait_for_progress"
148+
with patch(
149+
"ads.model.service.oci_datascience_model_deployment.wait_work_request"
150150
) as mock_wait:
151151
with patch.object(
152152
OCIDataScienceModelDeployment, "sync"
@@ -159,6 +159,7 @@ def test_activate_with_waiting(self):
159159
mock_activate.assert_called_with(self.mock_model_deployment.id)
160160
mock_wait.assert_called_with(
161161
"test",
162+
"Activating model deployment",
162163
1,
163164
1,
164165
)
@@ -227,8 +228,8 @@ def test_deactivate_with_waiting(self):
227228
"opc-work-request-id": "test",
228229
}
229230
mock_deactivate.return_value = response
230-
with patch.object(
231-
OCIWorkRequestMixin, "wait_for_progress"
231+
with patch(
232+
"ads.model.service.oci_datascience_model_deployment.wait_work_request"
232233
) as mock_wait:
233234
with patch.object(
234235
OCIDataScienceModelDeployment, "sync"
@@ -243,6 +244,7 @@ def test_deactivate_with_waiting(self):
243244
)
244245
mock_wait.assert_called_with(
245246
"test",
247+
"Deactivating model deployment",
246248
1,
247249
1,
248250
)
@@ -300,8 +302,8 @@ def test_create_with_waiting(self):
300302
**OCI_MODEL_DEPLOYMENT_PAYLOAD
301303
)
302304
mock_to_oci_mode.return_value = oci_model_deployment
303-
with patch.object(
304-
OCIWorkRequestMixin, "wait_for_progress"
305+
with patch(
306+
"ads.model.service.oci_datascience_model_deployment.wait_work_request"
305307
) as mock_wait:
306308
with patch("json.loads") as mock_json_load:
307309
create_model_deployment_details = MagicMock()
@@ -324,6 +326,7 @@ def test_create_with_waiting(self):
324326
mock_update_from_oci_model.assert_called()
325327
mock_wait.assert_called_with(
326328
"test",
329+
"Creating model deployment",
327330
1,
328331
1,
329332
)
@@ -439,8 +442,8 @@ def test_delete_with_waiting(self):
439442
mock_from_id.return_value = OCIDataScienceModelDeployment(
440443
**OCI_MODEL_DEPLOYMENT_PAYLOAD
441444
)
442-
with patch.object(
443-
OCIWorkRequestMixin, "wait_for_progress"
445+
with patch(
446+
"ads.model.service.oci_datascience_model_deployment.wait_work_request"
444447
) as mock_wait:
445448
with patch.object(
446449
oci.data_science.DataScienceClient,
@@ -461,6 +464,7 @@ def test_delete_with_waiting(self):
461464
mock_delete.assert_called_with(self.mock_model_deployment.id)
462465
mock_wait.assert_called_with(
463466
"test",
467+
"Deleting model deployment",
464468
1,
465469
1,
466470
)

0 commit comments

Comments
 (0)