|
5 | 5 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
6 | 6 |
|
7 | 7 | from functools import wraps
|
8 |
| -import json |
9 | 8 | import time
|
10 | 9 | import logging
|
11 | 10 | from typing import Callable, List
|
|
19 | 18 | from oci.data_science.models import (
|
20 | 19 | CreateModelDeploymentDetails,
|
21 | 20 | UpdateModelDeploymentDetails,
|
| 21 | + WorkRequest |
22 | 22 | )
|
23 | 23 |
|
24 | 24 | DEFAULT_WAIT_TIME = 1200
|
@@ -188,20 +188,13 @@ def activate(
|
188 | 188 | if wait_for_completion:
|
189 | 189 |
|
190 | 190 | self.workflow_req_id = response.headers.get("opc-work-request-id", None)
|
191 |
| - oci_model_deployment_object = self.client.get_model_deployment(self.id).data |
192 |
| - current_state = State._from_str(oci_model_deployment_object.lifecycle_state) |
193 |
| - model_deployment_id = self.id |
194 | 191 |
|
195 | 192 | try:
|
196 |
| - self._wait_for_progress_completion( |
197 |
| - State.ACTIVE.name, |
198 |
| - ACTIVATE_WORKFLOW_STEPS, |
199 |
| - [State.FAILED.name, State.INACTIVE.name], |
200 |
| - self.workflow_req_id, |
201 |
| - current_state, |
202 |
| - model_deployment_id, |
203 |
| - max_wait_time, |
204 |
| - poll_interval, |
| 193 | + self._wait_for_work_request( |
| 194 | + self.workflow_req_id, |
| 195 | + ACTIVATE_WORKFLOW_STEPS, |
| 196 | + max_wait_time, |
| 197 | + poll_interval |
205 | 198 | )
|
206 | 199 | except Exception as e:
|
207 | 200 | logger.error(
|
@@ -243,20 +236,13 @@ def create(
|
243 | 236 | if wait_for_completion:
|
244 | 237 |
|
245 | 238 | self.workflow_req_id = response.headers.get("opc-work-request-id", None)
|
246 |
| - res_payload = json.loads(str(response.data)) |
247 |
| - current_state = State._from_str(res_payload["lifecycle_state"]) |
248 |
| - model_deployment_id = self.id |
249 | 239 |
|
250 | 240 | try:
|
251 |
| - self._wait_for_progress_completion( |
252 |
| - State.ACTIVE.name, |
253 |
| - CREATE_WORKFLOW_STEPS, |
254 |
| - [State.FAILED.name, State.INACTIVE.name], |
255 |
| - self.workflow_req_id, |
256 |
| - current_state, |
257 |
| - model_deployment_id, |
258 |
| - max_wait_time, |
259 |
| - poll_interval, |
| 241 | + self._wait_for_work_request( |
| 242 | + self.workflow_req_id, |
| 243 | + CREATE_WORKFLOW_STEPS, |
| 244 | + max_wait_time, |
| 245 | + poll_interval |
260 | 246 | )
|
261 | 247 | except Exception as e:
|
262 | 248 | logger.error(
|
@@ -301,20 +287,13 @@ def deactivate(
|
301 | 287 | if wait_for_completion:
|
302 | 288 |
|
303 | 289 | self.workflow_req_id = response.headers.get("opc-work-request-id", None)
|
304 |
| - oci_model_deployment_object = self.client.get_model_deployment(self.id).data |
305 |
| - current_state = State._from_str(oci_model_deployment_object.lifecycle_state) |
306 |
| - model_deployment_id = self.id |
307 | 290 |
|
308 | 291 | try:
|
309 |
| - self._wait_for_progress_completion( |
310 |
| - State.INACTIVE.name, |
311 |
| - DEACTIVATE_WORKFLOW_STEPS, |
312 |
| - [State.FAILED.name], |
313 |
| - self.workflow_req_id, |
314 |
| - current_state, |
315 |
| - model_deployment_id, |
316 |
| - max_wait_time, |
317 |
| - poll_interval, |
| 292 | + self._wait_for_work_request( |
| 293 | + self.workflow_req_id, |
| 294 | + DEACTIVATE_WORKFLOW_STEPS, |
| 295 | + max_wait_time, |
| 296 | + poll_interval |
318 | 297 | )
|
319 | 298 | except Exception as e:
|
320 | 299 | logger.error(
|
@@ -359,20 +338,13 @@ def delete(
|
359 | 338 | if wait_for_completion:
|
360 | 339 |
|
361 | 340 | self.workflow_req_id = response.headers.get("opc-work-request-id", None)
|
362 |
| - oci_model_deployment_object = self.client.get_model_deployment(self.id).data |
363 |
| - current_state = State._from_str(oci_model_deployment_object.lifecycle_state) |
364 |
| - model_deployment_id = self.id |
365 | 341 |
|
366 | 342 | try:
|
367 |
| - self._wait_for_progress_completion( |
368 |
| - State.DELETED.name, |
369 |
| - DELETE_WORKFLOW_STEPS, |
370 |
| - [State.FAILED.name, State.INACTIVE.name], |
371 |
| - self.workflow_req_id, |
372 |
| - current_state, |
373 |
| - model_deployment_id, |
374 |
| - max_wait_time, |
375 |
| - poll_interval, |
| 343 | + self._wait_for_work_request( |
| 344 | + self.workflow_req_id, |
| 345 | + DELETE_WORKFLOW_STEPS, |
| 346 | + max_wait_time, |
| 347 | + poll_interval |
376 | 348 | )
|
377 | 349 | except Exception as e:
|
378 | 350 | logger.error(
|
@@ -512,89 +484,76 @@ def from_id(cls, model_deployment_id: str) -> "OCIDataScienceModelDeployment":
|
512 | 484 | """
|
513 | 485 | return super().from_ocid(model_deployment_id)
|
514 | 486 |
|
515 |
| - def _wait_for_progress_completion( |
516 |
| - self, |
517 |
| - final_state: str, |
518 |
| - work_flow_step: int, |
519 |
| - disallowed_final_states: List[str], |
520 |
| - work_flow_request_id: str, |
521 |
| - state: State, |
522 |
| - model_deployment_id: str, |
523 |
| - max_wait_time: int = DEFAULT_WAIT_TIME, |
524 |
| - poll_interval: int = DEFAULT_POLL_INTERVAL, |
525 |
| - ): |
526 |
| - """_wait_for_progress_completion blocks until progress is completed. |
| 487 | + def _wait_for_work_request( |
| 488 | + self, |
| 489 | + work_request_id: str, |
| 490 | + num_steps: int = DELETE_WORKFLOW_STEPS, |
| 491 | + max_wait_time: int = DEFAULT_WAIT_TIME, |
| 492 | + poll_interval: int = DEFAULT_POLL_INTERVAL |
| 493 | + ) -> None: |
| 494 | + """Waits for the work request to be completed. |
527 | 495 |
|
528 | 496 | Parameters
|
529 | 497 | ----------
|
530 |
| - final_state: str |
531 |
| - Final state of model deployment aimed to be reached. |
532 |
| - work_flow_step: int |
533 |
| - Number of work flow step of the request. |
534 |
| - disallowed_final_states: list[str] |
535 |
| - List of disallowed final state to be reached. |
536 |
| - work_flow_request_id: str |
537 |
| - The id of work flow request. |
538 |
| - state: State |
539 |
| - The current state of model deployment. |
540 |
| - model_deployment_id: str |
541 |
| - The ocid of model deployment. |
| 498 | + work_request_id: str |
| 499 | + Work Request OCID. |
| 500 | + num_steps: (int, optional). Defaults to 6. |
| 501 | + Number of steps for the progress indicator. |
542 | 502 | max_wait_time: int
|
543 | 503 | Maximum amount of time to wait in seconds (Defaults to 1200).
|
544 | 504 | Negative implies infinite wait time.
|
545 | 505 | poll_interval: int
|
546 | 506 | Poll interval in seconds (Defaults to 10).
|
| 507 | +
|
| 508 | + Returns |
| 509 | + ------- |
| 510 | + None |
547 | 511 | """
|
| 512 | + STOP_STATE = ( |
| 513 | + WorkRequest.STATUS_SUCCEEDED, |
| 514 | + WorkRequest.STATUS_CANCELED, |
| 515 | + WorkRequest.STATUS_FAILED, |
| 516 | + ) |
| 517 | + work_request_logs = [] |
548 | 518 |
|
| 519 | + i = 0 |
549 | 520 | start_time = time.time()
|
550 |
| - prev_message = "" |
551 |
| - prev_workflow_stage_len = 0 |
552 |
| - current_state = state or State.UNKNOWN |
553 |
| - with progress_bar_utils.get_progress_bar(work_flow_step) as progress: |
554 |
| - if max_wait_time > 0 and utils.seconds_since(start_time) >= max_wait_time: |
555 |
| - utils.get_logger().error( |
| 521 | + with progress_bar_utils.get_progress_bar(num_steps) as progress: |
| 522 | + exceed_max_time = max_wait_time > 0 and utils.seconds_since(start_time) >= max_wait_time |
| 523 | + if exceed_max_time: |
| 524 | + logger.error( |
556 | 525 | f"Max wait time ({max_wait_time} seconds) exceeded."
|
557 | 526 | )
|
558 |
| - while ( |
559 |
| - max_wait_time < 0 or utils.seconds_since(start_time) < max_wait_time |
560 |
| - ) and current_state.name.upper() != final_state: |
561 |
| - if current_state.name.upper() in disallowed_final_states: |
562 |
| - utils.get_logger().info( |
563 |
| - f"Operation failed due to deployment reaching state {current_state.name.upper()}. Use Deployment ID for further steps." |
564 |
| - ) |
565 |
| - break |
566 |
| - |
567 |
| - prev_state = current_state.name |
| 527 | + while not exceed_max_time and (not work_request_logs or len(work_request_logs) < num_steps): |
| 528 | + time.sleep(poll_interval) |
| 529 | + new_work_request_logs = [] |
| 530 | + |
568 | 531 | try:
|
569 |
| - model_deployment_payload = json.loads( |
570 |
| - str(self.client.get_model_deployment(model_deployment_id).data) |
571 |
| - ) |
572 |
| - current_state = ( |
573 |
| - State._from_str(model_deployment_payload["lifecycle_state"]) |
574 |
| - if "lifecycle_state" in model_deployment_payload |
575 |
| - else State.UNKNOWN |
576 |
| - ) |
577 |
| - workflow_payload = self.client.list_work_request_logs( |
578 |
| - work_flow_request_id |
| 532 | + work_request = self.client.get_work_request(work_request_id).data |
| 533 | + work_request_logs = self.client.list_work_request_logs( |
| 534 | + work_request_id |
579 | 535 | ).data
|
580 |
| - if isinstance(workflow_payload, list) and len(workflow_payload) > 0: |
581 |
| - if prev_message != workflow_payload[-1].message: |
582 |
| - for _ in range( |
583 |
| - len(workflow_payload) - prev_workflow_stage_len |
584 |
| - ): |
585 |
| - progress.update(workflow_payload[-1].message) |
586 |
| - prev_workflow_stage_len = len(workflow_payload) |
587 |
| - prev_message = workflow_payload[-1].message |
588 |
| - prev_workflow_stage_len = len(workflow_payload) |
589 |
| - if prev_state != current_state.name: |
590 |
| - utils.get_logger().info( |
591 |
| - f"Status Update: {current_state.name} in {utils.seconds_since(start_time)} seconds" |
592 |
| - ) |
593 |
| - except Exception as e: |
594 |
| - # utils.get_logger().warning( |
595 |
| - # "Unable to update deployment status. Details: %s", format( |
596 |
| - # e) |
597 |
| - # ) |
598 |
| - pass |
599 |
| - time.sleep(poll_interval) |
| 536 | + except Exception as ex: |
| 537 | + logger.warn(ex) |
| 538 | + |
| 539 | + new_work_request_logs = ( |
| 540 | + work_request_logs[i:] if work_request_logs else [] |
| 541 | + ) |
| 542 | + |
| 543 | + for wr_item in new_work_request_logs: |
| 544 | + progress.update(wr_item.message) |
| 545 | + i += 1 |
| 546 | + |
| 547 | + if work_request and work_request.status in STOP_STATE: |
| 548 | + if work_request.status != WorkRequest.STATUS_SUCCEEDED: |
| 549 | + if new_work_request_logs: |
| 550 | + raise Exception(new_work_request_logs[-1].message) |
| 551 | + else: |
| 552 | + raise Exception( |
| 553 | + "Error occurred in attempt to perform the operation. " |
| 554 | + "Check the service logs to get more details. " |
| 555 | + f"{work_request}" |
| 556 | + ) |
| 557 | + else: |
| 558 | + break |
600 | 559 | progress.update("Done")
|
0 commit comments