Skip to content

Commit 1aa0c47

Browse files
committed
Updated pr.
1 parent 672bbac commit 1aa0c47

File tree

6 files changed

+222
-184
lines changed

6 files changed

+222
-184
lines changed

ads/common/work_request.py

Lines changed: 48 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
4+
# Copyright (c) 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
import logging
@@ -26,7 +26,7 @@
2626
DEFAULT_BAR_FORMAT = '{l_bar}{bar}| [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]'
2727

2828

29-
class ADSWorkRequest(OCIDataScienceMixin):
29+
class DataScienceWorkRequest(OCIDataScienceMixin):
3030
"""Class for monitoring OCI WorkRequest and representing on tqdm progress bar. This class inherits
3131
`OCIDataScienceMixin` so as to call its `client` attribute to interact with OCI backend.
3232
"""
@@ -124,7 +124,7 @@ def watch(
124124
percent_change = self._percentage - previous_percent_complete
125125
previous_percent_complete = self._percentage
126126
progress_callback(
127-
progress=percent_change,
127+
percent_change=percent_change,
128128
description=self._description
129129
)
130130

@@ -141,53 +141,49 @@ def watch(
141141
else:
142142
break
143143

144-
progress_callback(progress=0, description="Done")
145-
146-
147-
def wait_work_request(
148-
id: str,
149-
progress_bar_description: str="Processing",
150-
max_wait_time: int=DEFAULT_WAIT_TIME,
151-
poll_interval: int=DEFAULT_POLL_INTERVAL
152-
):
153-
"""Waits for the work request progress bar to be completed.
154-
155-
Parameters
156-
----------
157-
id: str
158-
Work Request OCID.
159-
progress_bar_description: str
160-
Progress bar initial step description (Defaults to `Processing`).
161-
max_wait_time: int
162-
Maximum amount of time to wait in seconds (Defaults to 1200).
163-
Negative implies infinite wait time.
164-
poll_interval: int
165-
Poll interval in seconds (Defaults to 10).
166-
167-
Returns
168-
-------
169-
None
170-
"""
171-
ads_work_request = ADSWorkRequest(id)
172-
173-
with tqdm(
174-
total=WORK_REQUEST_PERCENTAGE,
175-
leave=False,
176-
mininterval=0,
177-
file=sys.stdout,
178-
desc=progress_bar_description,
179-
bar_format=DEFAULT_BAR_FORMAT
180-
) as pbar:
181-
182-
def progress_callback(percent_change, description):
183-
if percent_change != 0:
184-
pbar.update(percent_change)
185-
if description:
186-
pbar.set_description(description)
187-
188-
ads_work_request.watch(
189-
progress_callback,
190-
max_wait_time,
191-
poll_interval
192-
)
144+
progress_callback(percent_change=0, description="Done")
145+
146+
def wait_work_request(
147+
self,
148+
progress_bar_description: str="Processing",
149+
max_wait_time: int=DEFAULT_WAIT_TIME,
150+
poll_interval: int=DEFAULT_POLL_INTERVAL
151+
):
152+
"""Waits for the work request progress bar to be completed.
153+
154+
Parameters
155+
----------
156+
progress_bar_description: str
157+
Progress bar initial step description (Defaults to `Processing`).
158+
max_wait_time: int
159+
Maximum amount of time to wait in seconds (Defaults to 1200).
160+
Negative implies infinite wait time.
161+
poll_interval: int
162+
Poll interval in seconds (Defaults to 10).
163+
164+
Returns
165+
-------
166+
None
167+
"""
168+
169+
with tqdm(
170+
total=WORK_REQUEST_PERCENTAGE,
171+
leave=False,
172+
mininterval=0,
173+
file=sys.stdout,
174+
desc=progress_bar_description,
175+
bar_format=DEFAULT_BAR_FORMAT
176+
) as pbar:
177+
178+
def progress_callback(percent_change, description):
179+
if percent_change != 0:
180+
pbar.update(percent_change)
181+
if description:
182+
pbar.set_description(description)
183+
184+
self.watch(
185+
progress_callback=progress_callback,
186+
max_wait_time=max_wait_time,
187+
poll_interval=poll_interval
188+
)
193189

ads/model/service/oci_datascience_model.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
4+
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
import logging
@@ -17,7 +17,7 @@
1717
from ads.common.oci_mixin import OCIWorkRequestMixin
1818
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
1919
from ads.common.utils import extract_region
20-
from ads.common.work_request import wait_work_request
20+
from ads.common.work_request import DataScienceWorkRequest
2121
from ads.model.deployment import ModelDeployment
2222
from oci.data_science.models import (
2323
ArtifactExportDetailsObjectStorage,
@@ -362,7 +362,9 @@ def import_model_artifact(self, bucket_uri: str, region: str = None) -> None:
362362
).headers["opc-work-request-id"]
363363

364364
# Show progress of importing artifacts
365-
wait_work_request(work_request_id)
365+
DataScienceWorkRequest(work_request_id).wait_work_request(
366+
progress_bar_description="Importing model artifacts."
367+
)
366368
except ServiceError as ex:
367369
if ex.status == 404:
368370
raise ModelArtifactNotFoundError()
@@ -406,7 +408,9 @@ def export_model_artifact(self, bucket_uri: str, region: str = None):
406408
).headers["opc-work-request-id"]
407409

408410
# Show progress of exporting model artifacts
409-
wait_work_request(work_request_id)
411+
DataScienceWorkRequest(work_request_id).wait_work_request(
412+
progress_bar_description="Exporting model artifacts."
413+
)
410414

411415
@check_for_model_id(
412416
msg="Model needs to be saved to the Model Catalog before it can be updated."

ads/model/service/oci_datascience_model_deployment.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
4+
# Copyright (c) 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
from functools import wraps
88
import logging
99
from typing import Callable, List
1010
from ads.common.oci_datascience import OCIDataScienceMixin
11-
from ads.common.work_request import wait_work_request
11+
from ads.common.work_request import DataScienceWorkRequest
1212
from ads.config import PROJECT_OCID
1313
from ads.model.deployment.common.utils import OCIClientManager, State
1414
import oci
@@ -189,11 +189,10 @@ def activate(
189189
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
190190

191191
try:
192-
wait_work_request(
193-
self.workflow_req_id,
194-
"Activating model deployment",
195-
max_wait_time,
196-
poll_interval
192+
DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
193+
progress_bar_description="Activating model deployment",
194+
max_wait_time=max_wait_time,
195+
poll_interval=poll_interval
197196
)
198197
except Exception as e:
199198
logger.error(
@@ -240,11 +239,10 @@ def create(
240239
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
241240

242241
try:
243-
wait_work_request(
244-
self.workflow_req_id,
245-
"Creating model deployment",
246-
max_wait_time,
247-
poll_interval
242+
DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
243+
progress_bar_description="Creating model deployment",
244+
max_wait_time=max_wait_time,
245+
poll_interval=poll_interval
248246
)
249247
except Exception as e:
250248
logger.error("Error while trying to create model deployment: " + str(e))
@@ -294,11 +292,10 @@ def deactivate(
294292
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
295293

296294
try:
297-
wait_work_request(
298-
self.workflow_req_id,
299-
"Deactivating model deployment",
300-
max_wait_time,
301-
poll_interval
295+
DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
296+
progress_bar_description="Deactivating model deployment",
297+
max_wait_time=max_wait_time,
298+
poll_interval=poll_interval
302299
)
303300
except Exception as e:
304301
logger.error(
@@ -363,11 +360,10 @@ def delete(
363360
self.workflow_req_id = response.headers.get("opc-work-request-id", None)
364361

365362
try:
366-
wait_work_request(
367-
self.workflow_req_id,
368-
"Deleting model deployment",
369-
max_wait_time,
370-
poll_interval
363+
DataScienceWorkRequest(self.workflow_req_id).wait_work_request(
364+
progress_bar_description="Deleting model deployment",
365+
max_wait_time=max_wait_time,
366+
poll_interval=poll_interval
371367
)
372368
except Exception as e:
373369
logger.error("Error while trying to delete model deployment: " + str(e))

tests/unitary/default_setup/common/test_work_request.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import pytest
77
from unittest.mock import MagicMock, patch
8-
from ads.common.work_request import ADSWorkRequest
8+
from ads.common.work_request import DataScienceWorkRequest
99

1010

11-
class TestADSWorkRequest:
11+
class TestDataScienceWorkRequest:
1212

13-
@patch("ads.common.work_request.ADSWorkRequest._sync")
13+
@patch("ads.common.work_request.DataScienceWorkRequest._sync")
1414
@patch("ads.common.oci_datascience.OCIDataScienceMixin.__init__")
1515
def test_watch_succeed(self, mock_oci_datascience, mock_sync):
16-
ads_work_request = ADSWorkRequest(
16+
ads_work_request = DataScienceWorkRequest(
1717
id="test_id",
1818
description = "Processing"
1919
)
@@ -26,10 +26,10 @@ def test_watch_succeed(self, mock_oci_datascience, mock_sync):
2626
mock_oci_datascience.assert_called()
2727
mock_sync.assert_called()
2828

29-
@patch("ads.common.work_request.ADSWorkRequest._sync")
29+
@patch("ads.common.work_request.DataScienceWorkRequest._sync")
3030
@patch("ads.common.oci_datascience.OCIDataScienceMixin.__init__")
3131
def test_watch_failed_with_description(self, mock_oci_datascience, mock_sync):
32-
ads_work_request = ADSWorkRequest(
32+
ads_work_request = DataScienceWorkRequest(
3333
id="test_id",
3434
description = "Backend Error"
3535
)
@@ -43,10 +43,10 @@ def test_watch_failed_with_description(self, mock_oci_datascience, mock_sync):
4343
mock_oci_datascience.assert_called()
4444
mock_sync.assert_called()
4545

46-
@patch("ads.common.work_request.ADSWorkRequest._sync")
46+
@patch("ads.common.work_request.DataScienceWorkRequest._sync")
4747
@patch("ads.common.oci_datascience.OCIDataScienceMixin.__init__")
4848
def test_watch_failed_without_description(self, mock_oci_datascience, mock_sync):
49-
ads_work_request = ADSWorkRequest(
49+
ads_work_request = DataScienceWorkRequest(
5050
id="test_id",
5151
description = None
5252
)
@@ -64,3 +64,17 @@ def test_watch_failed_without_description(self, mock_oci_datascience, mock_sync)
6464
)
6565
mock_oci_datascience.assert_called()
6666
mock_sync.assert_called()
67+
68+
69+
@patch("ads.common.work_request.DataScienceWorkRequest._sync")
70+
@patch("ads.common.oci_datascience.OCIDataScienceMixin.__init__")
71+
def test_wait_work_request(self, mock_oci_datascience, mock_sync):
72+
ads_work_request = DataScienceWorkRequest(
73+
id="test_id",
74+
description = None
75+
)
76+
ads_work_request._percentage = 90
77+
ads_work_request._status = "SUCCEEDED"
78+
ads_work_request.wait_work_request(poll_interval=0)
79+
mock_oci_datascience.assert_called()
80+
mock_sync.assert_called()

tests/unitary/default_setup/model/test_oci_datascience_model.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
from unittest.mock import MagicMock, patch, call, PropertyMock
@@ -375,16 +375,19 @@ def test_model_deployment(self, mock_search):
375375
**{"kwargkey": "kwargvalue"},
376376
)
377377

378-
@patch("ads.model.service.oci_datascience_model.wait_work_request")
378+
@patch("ads.model.service.oci_datascience_model.DataScienceWorkRequest.wait_work_request")
379+
@patch("ads.model.service.oci_datascience_model.DataScienceWorkRequest.__init__")
379380
def test_import_model_artifact_success(
380381
self,
381-
mock_wait_for_work_request,
382+
mock_data_science_work_request,
383+
mock_wait_work_request,
382384
mock_client,
383385
):
384386
"""Tests importing model artifact content from the model catalog."""
385387
test_bucket_uri = "oci://bucket@namespace/prefix"
386388
test_bucket_details = ObjectStorageDetails.from_path(test_bucket_uri)
387389
test_region = "test_region"
390+
mock_data_science_work_request.return_value = None
388391
with patch.object(OCIDataScienceModel, "client", mock_client):
389392
self.mock_model.import_model_artifact(
390393
bucket_uri=test_bucket_uri, region=test_region
@@ -400,7 +403,10 @@ def test_import_model_artifact_success(
400403
)
401404
),
402405
)
403-
mock_wait_for_work_request.assert_called_with("work_request_id")
406+
mock_data_science_work_request.assert_called_with("work_request_id")
407+
mock_wait_work_request.assert_called_with(
408+
progress_bar_description='Importing model artifacts.'
409+
)
404410

405411
@patch.object(OCIDataScienceModel, "client")
406412
def test_import_model_artifact_fail(self, mock_client):
@@ -416,16 +422,19 @@ def test_import_model_artifact_fail(self, mock_client):
416422
bucket_uri=test_bucket_uri, region="test_region"
417423
)
418424

419-
@patch("ads.model.service.oci_datascience_model.wait_work_request")
425+
@patch("ads.model.service.oci_datascience_model.DataScienceWorkRequest.wait_work_request")
426+
@patch("ads.model.service.oci_datascience_model.DataScienceWorkRequest.__init__")
420427
def test_export_model_artifact(
421428
self,
422-
mock_wait_for_work_request,
429+
mock_data_science_work_request,
430+
mock_wait_work_request,
423431
mock_client,
424432
):
425433
"""Tests exporting model artifact to the model catalog."""
426434
test_bucket_uri = "oci://bucket@namespace/prefix"
427435
test_bucket_details = ObjectStorageDetails.from_path(test_bucket_uri)
428436
test_region = "test_region"
437+
mock_data_science_work_request.return_value = None
429438
with patch.object(OCIDataScienceModel, "client", mock_client):
430439
self.mock_model.export_model_artifact(
431440
bucket_uri=test_bucket_uri, region=test_region
@@ -441,4 +450,7 @@ def test_export_model_artifact(
441450
)
442451
),
443452
)
444-
mock_wait_for_work_request.assert_called_with("work_request_id")
453+
mock_data_science_work_request.assert_called_with("work_request_id")
454+
mock_wait_work_request.assert_called_with(
455+
progress_bar_description='Exporting model artifacts.'
456+
)

0 commit comments

Comments
 (0)