Skip to content

Commit 2174b36

Browse files
HF look ahead search
1 parent 6b31d0f commit 2174b36

File tree

3 files changed

+39
-1
lines changed

3 files changed

+39
-1
lines changed

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ clean:
1111
@find ./ -name 'Thumbs.db' -exec rm -f {} \;
1212
@find ./ -name '*~' -exec rm -f {} \;
1313
@find ./ -name '.DS_Store' -exec rm -f {} \;
14+
test:
15+
pip install -e .
16+
jupyter server extension enable --py ads.aqua.extension
17+
jupyter lab --NotebookApp.disable_check_xsrf=True --no-browser

ads/aqua/common/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,15 @@
5757
VLLM_INFERENCE_RESTRICTED_PARAMS,
5858
)
5959
from ads.aqua.data import AquaResourceIdentifier
60+
from ads.aqua.model.constants import ModelTask
6061
from ads.common.auth import AuthState, default_signer
6162
from ads.common.extended_enum import ExtendedEnumMeta
6263
from ads.common.object_storage_details import ObjectStorageDetails
6364
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
6465
from ads.common.utils import copy_file, get_console_link, upload_to_os
6566
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
6667
from ads.model import DataScienceModel, ModelVersionSet
68+
from tests.unitary.with_extras.model.score import model_name
6769

6870
logger = logging.getLogger("ads.aqua")
6971

@@ -1062,3 +1064,12 @@ def get_hf_model_info(repo_id: str) -> ModelInfo:
10621064
return HfApi().model_info(repo_id=repo_id)
10631065
except HfHubHTTPError as err:
10641066
raise format_hf_custom_error_message(err) from err
1067+
1068+
def list_hf_models(query:str) -> List[str]:
1069+
try:
1070+
models= HfApi().list_models(model_name=query,task=ModelTask.TEXT_GENERATION)
1071+
return [model.id for model in models if model.disabled is None]
1072+
except HfHubHTTPError as err:
1073+
raise format_hf_custom_error_message(err) from err
1074+
1075+

ads/aqua/extension/model_handler.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from ads.aqua.common.decorator import handle_exceptions
1111
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
12-
from ads.aqua.common.utils import get_hf_model_info
12+
from ads.aqua.common.utils import get_hf_model_info, list_hf_models
1313
from ads.aqua.extension.base_handler import AquaAPIhandler
1414
from ads.aqua.extension.errors import Errors
1515
from ads.aqua.model import AquaModelApp
@@ -177,6 +177,29 @@ def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:
177177

178178
return None
179179

180+
@handle_exceptions
181+
def get(self):
182+
"""
183+
Finds a list of matching models from hugging face based on query string provided from users.
184+
185+
Parameters
186+
----------
187+
query (str): The Hugging Face model name to search for.
188+
189+
Returns
190+
-------
191+
List[AquaModelSummary]
192+
Returns the matching AquaModelSummary object if found, else None.
193+
"""
194+
195+
query=self.get_argument("query",default=None)
196+
if not query:
197+
raise HTTPError(400,Errors.MISSING_REQUIRED_PARAMETER.format("query"))
198+
models=list_hf_models(query)
199+
return self.finish({"models":models})
200+
201+
202+
180203
@handle_exceptions
181204
def post(self, *args, **kwargs):
182205
"""Handles post request for the HF Models APIs

0 commit comments

Comments
 (0)