Skip to content

Commit 882a215

Browse files
added compatibility check
1 parent a76b698 commit 882a215

File tree

4 files changed

+23
-10
lines changed

4 files changed

+23
-10
lines changed

ads/aqua/common/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
)
5959
from ads.aqua.data import AquaResourceIdentifier
6060
from ads.common.auth import AuthState, default_signer
61+
from ads.common.decorator.threaded import threaded
6162
from ads.common.extended_enum import ExtendedEnumMeta
6263
from ads.common.object_storage_details import ObjectStorageDetails
6364
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
@@ -225,6 +226,7 @@ def read_file(file_path: str, **kwargs) -> str:
225226
return UNKNOWN
226227

227228

229+
@threaded()
228230
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
229231
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
230232
signer = default_signer() if artifact_path.startswith("oci://") else {}
@@ -1065,11 +1067,15 @@ def get_hf_model_info(repo_id: str) -> ModelInfo:
10651067

10661068

10671069
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
1068-
def list_hf_models(query:str) -> List[str]:
1070+
def list_hf_models(query: str) -> List[str]:
10691071
try:
1070-
models= HfApi().list_models(model_name=query,task="text-generation",sort="downloads",direction=-1,limit=20)
1072+
models = HfApi().list_models(
1073+
model_name=query,
1074+
task="text-generation",
1075+
sort="downloads",
1076+
direction=-1,
1077+
limit=20,
1078+
)
10711079
return [model.id for model in models if model.disabled is None]
10721080
except HfHubHTTPError as err:
10731081
raise format_hf_custom_error_message(err) from err
1074-
1075-

ads/aqua/extension/common_handler.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,15 @@
1111
from huggingface_hub.utils import LocalTokenNotFoundError
1212
from tornado.web import HTTPError
1313

14-
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
1514
from ads.aqua.common.decorator import handle_exceptions
1615
from ads.aqua.common.errors import AquaResourceAccessError, AquaRuntimeError
1716
from ads.aqua.common.utils import (
18-
fetch_service_compartment,
1917
get_huggingface_login_timeout,
2018
known_realm,
2119
)
2220
from ads.aqua.extension.base_handler import AquaAPIhandler
2321
from ads.aqua.extension.errors import Errors
22+
from ads.aqua.extension.utils import ui_compatability_check
2423

2524

2625
class ADSVersionHandler(AquaAPIhandler):
@@ -51,7 +50,7 @@ def get(self):
5150
AquaResourceAccessError: raised when aqua is not accessible in the given session/region.
5251
5352
"""
54-
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
53+
if ui_compatability_check():
5554
return self.finish({"status": "ok"})
5655
elif known_realm():
5756
return self.finish({"status": "compatible"})

ads/aqua/extension/common_ws_msg_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from importlib import metadata
88
from typing import List, Union
99

10-
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, fetch_service_compartment
1110
from ads.aqua.common.decorator import handle_exceptions
1211
from ads.aqua.common.errors import AquaResourceAccessError
1312
from ads.aqua.common.utils import known_realm
@@ -17,6 +16,7 @@
1716
CompatibilityCheckResponse,
1817
RequestResponseType,
1918
)
19+
from ads.aqua.extension.utils import ui_compatability_check
2020

2121

2222
class AquaCommonWsMsgHandler(AquaWSMsgHandler):
@@ -39,7 +39,7 @@ def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]:
3939
)
4040
return response
4141
if request.get("kind") == "CompatibilityCheck":
42-
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
42+
if ui_compatability_check():
4343
return CompatibilityCheckResponse(
4444
message_id=request.get("message_id"),
4545
kind=RequestResponseType.CompatibilityCheck,

ads/aqua/extension/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
from dataclasses import fields
5+
from datetime import datetime, timedelta
66
from typing import Dict, Optional
77

8+
from cachetools import TTLCache, cached
89
from tornado.web import HTTPError
910

11+
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
12+
from ads.aqua.common.utils import fetch_service_compartment
1013
from ads.aqua.extension.errors import Errors
1114

1215

@@ -21,3 +24,8 @@ def validate_function_parameters(data_class, input_data: Dict):
2124
raise HTTPError(
2225
400, Errors.MISSING_REQUIRED_PARAMETER.format(required_parameter)
2326
)
27+
28+
29+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
30+
def ui_compatability_check():
31+
return ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment()

0 commit comments

Comments
 (0)