Skip to content

Commit b9caf91

Browse files
committed
Adds timeout option for loading AQUA config.
1 parent 69222e8 commit b9caf91

File tree

2 files changed

+104
-5
lines changed

2 files changed

+104
-5
lines changed

ads/aqua/common/utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@
2828
from ads.aqua.constants import *
2929
from ads.aqua.data import AquaResourceIdentifier
3030
from ads.common.auth import default_signer
31+
from ads.common.decorator import threaded
3132
from ads.common.extended_enum import ExtendedEnumMeta
3233
from ads.common.object_storage_details import ObjectStorageDetails
3334
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
34-
from ads.common.utils import get_console_link, upload_to_os, copy_file
35+
from ads.common.utils import copy_file, get_console_link, upload_to_os
3536
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
3637
from ads.model import DataScienceModel, ModelVersionSet
3738

@@ -195,6 +196,7 @@ def read_file(file_path: str, **kwargs) -> str:
195196
return UNKNOWN
196197

197198

199+
@threaded(timeout=5)
198200
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
199201
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
200202
if artifact_path.startswith("oci://"):
@@ -540,8 +542,10 @@ def get_container_image(
540542

541543

542544
def fetch_service_compartment() -> Union[str, None]:
543-
"""Loads the compartment mapping json from service bucket. This json file has a service-model-compartment key which
544-
contains a dictionary of namespaces and the compartment OCID of the service models in that namespace.
545+
"""
546+
Loads the compartment mapping json from service bucket.
547+
This json file has a service-model-compartment key which contains a dictionary of namespaces
548+
and the compartment OCID of the service models in that namespace.
545549
"""
546550
config_file_name = (
547551
f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
@@ -554,8 +558,8 @@ def fetch_service_compartment() -> Union[str, None]:
554558
)
555559
except Exception as e:
556560
logger.debug(
557-
f"Config file {config_file_name}/{CONTAINER_INDEX} to fetch service compartment OCID could not be found. "
558-
f"\n{str(e)}."
561+
f"Config file {config_file_name}/{CONTAINER_INDEX} to fetch service compartment OCID "
562+
f"could not be found. \n{str(e)}."
559563
)
560564
return
561565
compartment_mapping = config.get(COMPARTMENT_MAPPING_KEY)

ads/common/decorator/threaded.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8; -*-
3+
4+
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
8+
import concurrent.futures
9+
import functools
10+
import logging
11+
from typing import Optional
12+
13+
from git import Optional
14+
15+
logger = logging.getLogger(__name__)
16+
17+
# Create a global thread pool with a maximum of 10 threads
18+
thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=10)
19+
20+
21+
class TimeoutError(Exception):
22+
"""
23+
Custom exception to be raised when a function times out.
24+
25+
Attributes
26+
----------
27+
message : str
28+
The error message describing what went wrong.
29+
30+
Parameters
31+
----------
32+
message : str
33+
The error message.
34+
"""
35+
36+
def __init__(
37+
self,
38+
func_name: str,
39+
timeout: int,
40+
message: Optional[str] = "The operation could not be completed in time.",
41+
):
42+
super().__init__(
43+
f"{message} The function '{func_name}' exceeded the timeout of {timeout} seconds."
44+
)
45+
46+
47+
def threaded(timeout=None):
48+
"""
49+
Decorator to run a function in a separate thread using a global thread pool.
50+
51+
Parameters
52+
----------
53+
timeout (int, optional)
54+
The maximum time in seconds to wait for the function to complete.
55+
If the function does not complete within this time, "timeout" is returned.
56+
57+
Returns
58+
-------
59+
function: The wrapped function that will run in a separate thread with the specified timeout.
60+
"""
61+
62+
def decorator(func):
63+
@functools.wraps(func)
64+
def wrapper(*args, **kwargs):
65+
"""
66+
Wrapper function to submit the decorated function to the thread pool and handle timeout.
67+
68+
Parameters
69+
----------
70+
*args: Positional arguments to pass to the decorated function.
71+
**kwargs: Keyword arguments to pass to the decorated function.
72+
73+
Returns
74+
-------
75+
Any: The result of the decorated function if it completes within the timeout.
76+
77+
Raise
78+
-----
79+
TimeoutError
80+
In case of the function exceeded the timeout.
81+
"""
82+
future = thread_pool.submit(func, *args, **kwargs)
83+
try:
84+
return future.result(timeout=timeout)
85+
except concurrent.futures.TimeoutError as ex:
86+
logger.debug(
87+
f"The function '{func.__name__}' "
88+
f"exceeded the timeout of {timeout} seconds. "
89+
f"{ex}"
90+
)
91+
raise TimeoutError(func.__name__, timeout)
92+
93+
return wrapper
94+
95+
return decorator

0 commit comments

Comments
 (0)