Skip to content

Commit 5b4d2d3

Browse files
Updated compatibility check for aqua (#952)
2 parents a76b698 + 2a27ef6 commit 5b4d2d3

File tree

6 files changed

+47
-17
lines changed

6 files changed

+47
-17
lines changed

ads/aqua/common/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
)
5959
from ads.aqua.data import AquaResourceIdentifier
6060
from ads.common.auth import AuthState, default_signer
61+
from ads.common.decorator.threaded import threaded
6162
from ads.common.extended_enum import ExtendedEnumMeta
6263
from ads.common.object_storage_details import ObjectStorageDetails
6364
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
@@ -225,6 +226,7 @@ def read_file(file_path: str, **kwargs) -> str:
225226
return UNKNOWN
226227

227228

229+
@threaded()
228230
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
229231
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
230232
signer = default_signer() if artifact_path.startswith("oci://") else {}
@@ -1065,11 +1067,15 @@ def get_hf_model_info(repo_id: str) -> ModelInfo:
10651067

10661068

10671069
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
1068-
def list_hf_models(query:str) -> List[str]:
1070+
def list_hf_models(query: str) -> List[str]:
10691071
try:
1070-
models= HfApi().list_models(model_name=query,task="text-generation",sort="downloads",direction=-1,limit=20)
1072+
models = HfApi().list_models(
1073+
model_name=query,
1074+
task="text-generation",
1075+
sort="downloads",
1076+
direction=-1,
1077+
limit=20,
1078+
)
10711079
return [model.id for model in models if model.disabled is None]
10721080
except HfHubHTTPError as err:
10731081
raise format_hf_custom_error_message(err) from err
1074-
1075-

ads/aqua/extension/common_handler.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,15 @@
1111
from huggingface_hub.utils import LocalTokenNotFoundError
1212
from tornado.web import HTTPError
1313

14-
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
1514
from ads.aqua.common.decorator import handle_exceptions
1615
from ads.aqua.common.errors import AquaResourceAccessError, AquaRuntimeError
1716
from ads.aqua.common.utils import (
18-
fetch_service_compartment,
1917
get_huggingface_login_timeout,
2018
known_realm,
2119
)
2220
from ads.aqua.extension.base_handler import AquaAPIhandler
2321
from ads.aqua.extension.errors import Errors
22+
from ads.aqua.extension.utils import ui_compatability_check
2423

2524

2625
class ADSVersionHandler(AquaAPIhandler):
@@ -51,7 +50,7 @@ def get(self):
5150
AquaResourceAccessError: raised when aqua is not accessible in the given session/region.
5251
5352
"""
54-
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
53+
if ui_compatability_check():
5554
return self.finish({"status": "ok"})
5655
elif known_realm():
5756
return self.finish({"status": "compatible"})

ads/aqua/extension/common_ws_msg_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from importlib import metadata
88
from typing import List, Union
99

10-
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, fetch_service_compartment
1110
from ads.aqua.common.decorator import handle_exceptions
1211
from ads.aqua.common.errors import AquaResourceAccessError
1312
from ads.aqua.common.utils import known_realm
@@ -17,6 +16,7 @@
1716
CompatibilityCheckResponse,
1817
RequestResponseType,
1918
)
19+
from ads.aqua.extension.utils import ui_compatability_check
2020

2121

2222
class AquaCommonWsMsgHandler(AquaWSMsgHandler):
@@ -39,7 +39,7 @@ def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]:
3939
)
4040
return response
4141
if request.get("kind") == "CompatibilityCheck":
42-
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
42+
if ui_compatability_check():
4343
return CompatibilityCheckResponse(
4444
message_id=request.get("message_id"),
4545
kind=RequestResponseType.CompatibilityCheck,

ads/aqua/extension/utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
from dataclasses import fields
5+
from datetime import datetime, timedelta
66
from typing import Dict, Optional
77

8+
from cachetools import TTLCache, cached
89
from tornado.web import HTTPError
910

11+
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
12+
from ads.aqua.common.utils import fetch_service_compartment
1013
from ads.aqua.extension.errors import Errors
1114

1215

@@ -21,3 +24,11 @@ def validate_function_parameters(data_class, input_data: Dict):
2124
raise HTTPError(
2225
400, Errors.MISSING_REQUIRED_PARAMETER.format(required_parameter)
2326
)
27+
28+
29+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
30+
def ui_compatability_check():
31+
"""This method caches the service compartment OCID details that is set by either the environment variable or if
32+
fetched from the configuration. The cached result is returned when multiple calls are made in quick succession
33+
from the UI to avoid multiple config file loads."""
34+
return ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment()

tests/unitary/with_extras/aqua/test_common_handler.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import ads.config
1616
from ads.aqua.constants import AQUA_GA_LIST
1717
from ads.aqua.extension.common_handler import CompatibilityCheckHandler
18+
from ads.aqua.extension.utils import ui_compatability_check
1819

1920

2021
class TestDataset:
@@ -28,6 +29,9 @@ def setUp(self, ipython_init_mock) -> None:
2829
self.common_handler = CompatibilityCheckHandler(MagicMock(), MagicMock())
2930
self.common_handler.request = MagicMock()
3031

32+
def tearDown(self) -> None:
33+
ui_compatability_check.cache_clear()
34+
3135
def test_get_ok(self):
3236
"""Test to check if ok is returned when ODSC_MODEL_COMPARTMENT_OCID is set."""
3337
with patch.dict(
@@ -36,15 +40,22 @@ def test_get_ok(self):
3640
):
3741
reload(ads.config)
3842
reload(ads.aqua)
43+
reload(ads.aqua.extension.utils)
3944
reload(ads.aqua.extension.common_handler)
4045

4146
with patch(
4247
"ads.aqua.extension.base_handler.AquaAPIhandler.finish"
4348
) as mock_finish:
44-
mock_finish.side_effect = lambda x: x
45-
self.common_handler.request.path = "aqua/hello"
46-
result = self.common_handler.get()
47-
assert result["status"] == "ok"
49+
with patch(
50+
"ads.aqua.extension.utils.fetch_service_compartment"
51+
) as mock_fetch_service_compartment:
52+
mock_fetch_service_compartment.return_value = (
53+
TestDataset.SERVICE_COMPARTMENT_ID
54+
)
55+
mock_finish.side_effect = lambda x: x
56+
self.common_handler.request.path = "aqua/hello"
57+
result = self.common_handler.get()
58+
assert result["status"] == "ok"
4859

4960
def test_get_compatible_status(self):
5061
"""Test to check if compatible is returned when ODSC_MODEL_COMPARTMENT_OCID is not set
@@ -55,12 +66,13 @@ def test_get_compatible_status(self):
5566
):
5667
reload(ads.config)
5768
reload(ads.aqua)
69+
reload(ads.aqua.extension.utils)
5870
reload(ads.aqua.extension.common_handler)
5971
with patch(
6072
"ads.aqua.extension.base_handler.AquaAPIhandler.finish"
6173
) as mock_finish:
6274
with patch(
63-
"ads.aqua.extension.common_handler.fetch_service_compartment"
75+
"ads.aqua.extension.utils.fetch_service_compartment"
6476
) as mock_fetch_service_compartment:
6577
mock_fetch_service_compartment.return_value = None
6678
mock_finish.side_effect = lambda x: x
@@ -77,12 +89,13 @@ def test_raise_not_compatible_error(self):
7789
):
7890
reload(ads.config)
7991
reload(ads.aqua)
92+
reload(ads.aqua.extension.utils)
8093
reload(ads.aqua.extension.common_handler)
8194
with patch(
8295
"ads.aqua.extension.base_handler.AquaAPIhandler.finish"
8396
) as mock_finish:
8497
with patch(
85-
"ads.aqua.extension.common_handler.fetch_service_compartment"
98+
"ads.aqua.extension.utils.fetch_service_compartment"
8699
) as mock_fetch_service_compartment:
87100
mock_fetch_service_compartment.return_value = None
88101
mock_finish.side_effect = lambda x: x

tests/unitary/with_extras/aqua/test_handlers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from notebook.base.handlers import APIHandler, IPythonHandler
1414
from oci.exceptions import ServiceError
1515
from parameterized import parameterized
16-
from tornado.httpserver import HTTPRequest
1716
from tornado.httputil import HTTPServerRequest
1817
from tornado.web import Application, HTTPError
1918

@@ -191,6 +190,7 @@ def setUpClass(cls):
191190

192191
reload(ads.config)
193192
reload(ads.aqua)
193+
reload(ads.aqua.extension.utils)
194194
reload(ads.aqua.extension.common_handler)
195195

196196
@classmethod
@@ -200,6 +200,7 @@ def tearDownClass(cls):
200200

201201
reload(ads.config)
202202
reload(ads.aqua)
203+
reload(ads.aqua.extension.utils)
203204
reload(ads.aqua.extension.common_handler)
204205

205206
@parameterized.expand(

0 commit comments

Comments
 (0)