|
5 | 5 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6 | 6 |
|
7 | 7 | from functools import wraps
|
8 |
| -import json |
9 |
| -import time |
10 | 8 | import logging
|
11 | 9 | from typing import Callable, List
|
12 | 10 | from ads.common.oci_datascience import OCIDataScienceMixin
|
13 |
| -from ads.common import utils as progress_bar_utils |
| 11 | +from ads.common.oci_mixin import OCIWorkRequestMixin |
14 | 12 | from ads.config import PROJECT_OCID
|
15 |
| -from ads.model.deployment.common import utils |
16 | 13 | from ads.model.deployment.common.utils import OCIClientManager, State
|
17 | 14 | import oci
|
18 | 15 |
|
@@ -84,6 +81,7 @@ class MissingModelDeploymentWorkflowIdError(Exception): # pragma: no cover
|
84 | 81 |
|
85 | 82 | class OCIDataScienceModelDeployment(
|
86 | 83 | OCIDataScienceMixin,
|
| 84 | + OCIWorkRequestMixin, |
87 | 85 | oci.data_science.models.ModelDeployment,
|
88 | 86 | ):
|
89 | 87 | """Represents an OCI Data Science Model Deployment.
|
@@ -188,20 +186,13 @@ def activate(
|
188 | 186 | if wait_for_completion:
|
189 | 187 |
|
190 | 188 | self.workflow_req_id = response.headers.get("opc-work-request-id", None)
|
191 |
| - oci_model_deployment_object = self.client.get_model_deployment(self.id).data |
192 |
| - current_state = State._from_str(oci_model_deployment_object.lifecycle_state) |
193 |
| - model_deployment_id = self.id |
194 | 189 |
|
195 | 190 | try:
|
196 |
| - self._wait_for_progress_completion( |
197 |
| - State.ACTIVE.name, |
198 |
| - ACTIVATE_WORKFLOW_STEPS, |
199 |
| - [State.FAILED.name, State.INACTIVE.name], |
200 |
| - self.workflow_req_id, |
201 |
| - current_state, |
202 |
| - model_deployment_id, |
203 |
| - max_wait_time, |
204 |
| - poll_interval, |
| 191 | + self.wait_for_progress( |
| 192 | + self.workflow_req_id, |
| 193 | + ACTIVATE_WORKFLOW_STEPS, |
| 194 | + max_wait_time, |
| 195 | + poll_interval |
205 | 196 | )
|
206 | 197 | except Exception as e:
|
207 | 198 | logger.error(
|
@@ -243,20 +234,13 @@ def create(
|
243 | 234 | if wait_for_completion:
|
244 | 235 |
|
245 | 236 | self.workflow_req_id = response.headers.get("opc-work-request-id", None)
|
246 |
| - res_payload = json.loads(str(response.data)) |
247 |
| - current_state = State._from_str(res_payload["lifecycle_state"]) |
248 |
| - model_deployment_id = self.id |
249 | 237 |
|
250 | 238 | try:
|
251 |
| - self._wait_for_progress_completion( |
252 |
| - State.ACTIVE.name, |
253 |
| - CREATE_WORKFLOW_STEPS, |
254 |
| - [State.FAILED.name, State.INACTIVE.name], |
255 |
| - self.workflow_req_id, |
256 |
| - current_state, |
257 |
| - model_deployment_id, |
258 |
| - max_wait_time, |
259 |
| - poll_interval, |
| 239 | + self.wait_for_progress( |
| 240 | + self.workflow_req_id, |
| 241 | + CREATE_WORKFLOW_STEPS, |
| 242 | + max_wait_time, |
| 243 | + poll_interval |
260 | 244 | )
|
261 | 245 | except Exception as e:
|
262 | 246 | logger.error(
|
@@ -301,20 +285,13 @@ def deactivate(
|
301 | 285 | if wait_for_completion:
|
302 | 286 |
|
303 | 287 | self.workflow_req_id = response.headers.get("opc-work-request-id", None)
|
304 |
| - oci_model_deployment_object = self.client.get_model_deployment(self.id).data |
305 |
| - current_state = State._from_str(oci_model_deployment_object.lifecycle_state) |
306 |
| - model_deployment_id = self.id |
307 | 288 |
|
308 | 289 | try:
|
309 |
| - self._wait_for_progress_completion( |
310 |
| - State.INACTIVE.name, |
311 |
| - DEACTIVATE_WORKFLOW_STEPS, |
312 |
| - [State.FAILED.name], |
313 |
| - self.workflow_req_id, |
314 |
| - current_state, |
315 |
| - model_deployment_id, |
316 |
| - max_wait_time, |
317 |
| - poll_interval, |
| 290 | + self.wait_for_progress( |
| 291 | + self.workflow_req_id, |
| 292 | + DEACTIVATE_WORKFLOW_STEPS, |
| 293 | + max_wait_time, |
| 294 | + poll_interval |
318 | 295 | )
|
319 | 296 | except Exception as e:
|
320 | 297 | logger.error(
|
@@ -359,20 +336,13 @@ def delete(
|
359 | 336 | if wait_for_completion:
|
360 | 337 |
|
361 | 338 | self.workflow_req_id = response.headers.get("opc-work-request-id", None)
|
362 |
| - oci_model_deployment_object = self.client.get_model_deployment(self.id).data |
363 |
| - current_state = State._from_str(oci_model_deployment_object.lifecycle_state) |
364 |
| - model_deployment_id = self.id |
365 | 339 |
|
366 | 340 | try:
|
367 |
| - self._wait_for_progress_completion( |
368 |
| - State.DELETED.name, |
369 |
| - DELETE_WORKFLOW_STEPS, |
370 |
| - [State.FAILED.name, State.INACTIVE.name], |
371 |
| - self.workflow_req_id, |
372 |
| - current_state, |
373 |
| - model_deployment_id, |
374 |
| - max_wait_time, |
375 |
| - poll_interval, |
| 341 | + self.wait_for_progress( |
| 342 | + self.workflow_req_id, |
| 343 | + DELETE_WORKFLOW_STEPS, |
| 344 | + max_wait_time, |
| 345 | + poll_interval |
376 | 346 | )
|
377 | 347 | except Exception as e:
|
378 | 348 | logger.error(
|
@@ -511,90 +481,3 @@ def from_id(cls, model_deployment_id: str) -> "OCIDataScienceModelDeployment":
|
511 | 481 | An instance of `OCIDataScienceModelDeployment`.
|
512 | 482 | """
|
513 | 483 | return super().from_ocid(model_deployment_id)
|
514 |
| - |
515 |
| - def _wait_for_progress_completion( |
516 |
| - self, |
517 |
| - final_state: str, |
518 |
| - work_flow_step: int, |
519 |
| - disallowed_final_states: List[str], |
520 |
| - work_flow_request_id: str, |
521 |
| - state: State, |
522 |
| - model_deployment_id: str, |
523 |
| - max_wait_time: int = DEFAULT_WAIT_TIME, |
524 |
| - poll_interval: int = DEFAULT_POLL_INTERVAL, |
525 |
| - ): |
526 |
| - """_wait_for_progress_completion blocks until progress is completed. |
527 |
| -
|
528 |
| - Parameters |
529 |
| - ---------- |
530 |
| - final_state: str |
531 |
| - Final state of model deployment aimed to be reached. |
532 |
| - work_flow_step: int |
533 |
| - Number of work flow step of the request. |
534 |
| - disallowed_final_states: list[str] |
535 |
| - List of disallowed final state to be reached. |
536 |
| - work_flow_request_id: str |
537 |
| - The id of work flow request. |
538 |
| - state: State |
539 |
| - The current state of model deployment. |
540 |
| - model_deployment_id: str |
541 |
| - The ocid of model deployment. |
542 |
| - max_wait_time: int |
543 |
| - Maximum amount of time to wait in seconds (Defaults to 1200). |
544 |
| - Negative implies infinite wait time. |
545 |
| - poll_interval: int |
546 |
| - Poll interval in seconds (Defaults to 10). |
547 |
| - """ |
548 |
| - |
549 |
| - start_time = time.time() |
550 |
| - prev_message = "" |
551 |
| - prev_workflow_stage_len = 0 |
552 |
| - current_state = state or State.UNKNOWN |
553 |
| - with progress_bar_utils.get_progress_bar(work_flow_step) as progress: |
554 |
| - if max_wait_time > 0 and utils.seconds_since(start_time) >= max_wait_time: |
555 |
| - utils.get_logger().error( |
556 |
| - f"Max wait time ({max_wait_time} seconds) exceeded." |
557 |
| - ) |
558 |
| - while ( |
559 |
| - max_wait_time < 0 or utils.seconds_since(start_time) < max_wait_time |
560 |
| - ) and current_state.name.upper() != final_state: |
561 |
| - if current_state.name.upper() in disallowed_final_states: |
562 |
| - utils.get_logger().info( |
563 |
| - f"Operation failed due to deployment reaching state {current_state.name.upper()}. Use Deployment ID for further steps." |
564 |
| - ) |
565 |
| - break |
566 |
| - |
567 |
| - prev_state = current_state.name |
568 |
| - try: |
569 |
| - model_deployment_payload = json.loads( |
570 |
| - str(self.client.get_model_deployment(model_deployment_id).data) |
571 |
| - ) |
572 |
| - current_state = ( |
573 |
| - State._from_str(model_deployment_payload["lifecycle_state"]) |
574 |
| - if "lifecycle_state" in model_deployment_payload |
575 |
| - else State.UNKNOWN |
576 |
| - ) |
577 |
| - workflow_payload = self.client.list_work_request_logs( |
578 |
| - work_flow_request_id |
579 |
| - ).data |
580 |
| - if isinstance(workflow_payload, list) and len(workflow_payload) > 0: |
581 |
| - if prev_message != workflow_payload[-1].message: |
582 |
| - for _ in range( |
583 |
| - len(workflow_payload) - prev_workflow_stage_len |
584 |
| - ): |
585 |
| - progress.update(workflow_payload[-1].message) |
586 |
| - prev_workflow_stage_len = len(workflow_payload) |
587 |
| - prev_message = workflow_payload[-1].message |
588 |
| - prev_workflow_stage_len = len(workflow_payload) |
589 |
| - if prev_state != current_state.name: |
590 |
| - utils.get_logger().info( |
591 |
| - f"Status Update: {current_state.name} in {utils.seconds_since(start_time)} seconds" |
592 |
| - ) |
593 |
| - except Exception as e: |
594 |
| - # utils.get_logger().warning( |
595 |
| - # "Unable to update deployment status. Details: %s", format( |
596 |
| - # e) |
597 |
| - # ) |
598 |
| - pass |
599 |
| - time.sleep(poll_interval) |
600 |
| - progress.update("Done") |
0 commit comments