Skip to content

Commit d49a498

Browse files
authored
Adds timeout option for loading AQUA config. (#878)
2 parents 69222e8 + fac365f commit d49a498

File tree

3 files changed

+109
-5
lines changed

3 files changed

+109
-5
lines changed

ads/aqua/common/utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@
2828
from ads.aqua.constants import *
2929
from ads.aqua.data import AquaResourceIdentifier
3030
from ads.common.auth import default_signer
31+
from ads.common.decorator.threaded import threaded
3132
from ads.common.extended_enum import ExtendedEnumMeta
3233
from ads.common.object_storage_details import ObjectStorageDetails
3334
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
34-
from ads.common.utils import get_console_link, upload_to_os, copy_file
35+
from ads.common.utils import copy_file, get_console_link, upload_to_os
3536
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
3637
from ads.model import DataScienceModel, ModelVersionSet
3738

@@ -195,6 +196,7 @@ def read_file(file_path: str, **kwargs) -> str:
195196
return UNKNOWN
196197

197198

199+
@threaded()
198200
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
199201
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
200202
if artifact_path.startswith("oci://"):
@@ -540,8 +542,10 @@ def get_container_image(
540542

541543

542544
def fetch_service_compartment() -> Union[str, None]:
543-
"""Loads the compartment mapping json from service bucket. This json file has a service-model-compartment key which
544-
contains a dictionary of namespaces and the compartment OCID of the service models in that namespace.
545+
"""
546+
Loads the compartment mapping json from service bucket.
547+
This json file has a service-model-compartment key which contains a dictionary of namespaces
548+
and the compartment OCID of the service models in that namespace.
545549
"""
546550
config_file_name = (
547551
f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
@@ -554,8 +558,8 @@ def fetch_service_compartment() -> Union[str, None]:
554558
)
555559
except Exception as e:
556560
logger.debug(
557-
f"Config file {config_file_name}/{CONTAINER_INDEX} to fetch service compartment OCID could not be found. "
558-
f"\n{str(e)}."
561+
f"Config file {config_file_name}/{CONTAINER_INDEX} to fetch service compartment OCID "
562+
f"could not be found. \n{str(e)}."
559563
)
560564
return
561565
compartment_mapping = config.get(COMPARTMENT_MAPPING_KEY)

ads/common/decorator/threaded.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8; -*-
3+
4+
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
8+
import concurrent.futures
9+
import functools
10+
import logging
11+
from typing import Optional
12+
13+
from git import Optional
14+
15+
from ads.config import THREADED_DEFAULT_TIMEOUT
16+
17+
logger = logging.getLogger(__name__)
18+
19+
# Create a global thread pool with a maximum of 10 threads
20+
thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=10)
21+
22+
23+
class TimeoutError(Exception):
24+
"""
25+
Custom exception to be raised when a function times out.
26+
27+
Attributes
28+
----------
29+
message : str
30+
The error message describing what went wrong.
31+
32+
Parameters
33+
----------
34+
message : str
35+
The error message.
36+
"""
37+
38+
def __init__(
39+
self,
40+
func_name: str,
41+
timeout: int,
42+
message: Optional[str] = "The operation could not be completed in time.",
43+
):
44+
super().__init__(
45+
f"{message} The function '{func_name}' exceeded the timeout of {timeout} seconds."
46+
)
47+
48+
49+
def threaded(timeout: Optional[int] = THREADED_DEFAULT_TIMEOUT):
50+
"""
51+
Decorator to run a function in a separate thread using a global thread pool.
52+
53+
Parameters
54+
----------
55+
timeout (int, optional)
56+
The maximum time in seconds to wait for the function to complete.
57+
If the function does not complete within this time, "timeout" is returned.
58+
59+
Returns
60+
-------
61+
function: The wrapped function that will run in a separate thread with the specified timeout.
62+
"""
63+
64+
def decorator(func):
65+
@functools.wraps(func)
66+
def wrapper(*args, **kwargs):
67+
"""
68+
Wrapper function to submit the decorated function to the thread pool and handle timeout.
69+
70+
Parameters
71+
----------
72+
*args: Positional arguments to pass to the decorated function.
73+
**kwargs: Keyword arguments to pass to the decorated function.
74+
75+
Returns
76+
-------
77+
Any: The result of the decorated function if it completes within the timeout.
78+
79+
Raise
80+
-----
81+
TimeoutError
82+
In case of the function exceeded the timeout.
83+
"""
84+
future = thread_pool.submit(func, *args, **kwargs)
85+
try:
86+
return future.result(timeout=timeout)
87+
except concurrent.futures.TimeoutError as ex:
88+
logger.debug(
89+
f"The function '{func.__name__}' "
90+
f"exceeded the timeout of {timeout} seconds. "
91+
f"{ex}"
92+
)
93+
raise TimeoutError(func.__name__, timeout)
94+
95+
return wrapper
96+
97+
return decorator

ads/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import inspect
99
import os
1010
from typing import Dict, Optional
11+
1112
from ads.common.config import DEFAULT_CONFIG_PATH, DEFAULT_CONFIG_PROFILE, Config, Mode
1213

1314
OCI_ODSC_SERVICE_ENDPOINT = os.environ.get("OCI_ODSC_SERVICE_ENDPOINT")
@@ -85,6 +86,8 @@
8586
AQUA_SERVICE_NAME = "aqua"
8687
DATA_SCIENCE_SERVICE_NAME = "data-science"
8788

89+
THREADED_DEFAULT_TIMEOUT = os.environ.get("THREADED_DEFAULT_TIMEOUT", 5)
90+
8891

8992
def export(
9093
uri: Optional[str] = DEFAULT_CONFIG_PATH,

0 commit comments

Comments
 (0)