Skip to content

Commit 83ce1df

Browse files
HF look ahead search
1 parent 833b269 commit 83ce1df

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

ads/aqua/common/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,4 +1064,12 @@ def get_hf_model_info(repo_id: str) -> ModelInfo:
10641064
raise format_hf_custom_error_message(err) from err
10651065

10661066

1067+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
1068+
def list_hf_models(query:str) -> List[str]:
1069+
try:
1070+
models= HfApi().list_models(model_name=query,task="text-generation")
1071+
return [model.id for model in models if model.disabled is None]
1072+
except HfHubHTTPError as err:
1073+
raise format_hf_custom_error_message(err) from err
1074+
10671075

ads/aqua/extension/model_handler.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from ads.aqua.common.decorator import handle_exceptions
1111
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
12-
from ads.aqua.common.utils import get_hf_model_info
12+
from ads.aqua.common.utils import get_hf_model_info, list_hf_models
1313
from ads.aqua.extension.base_handler import AquaAPIhandler
1414
from ads.aqua.extension.errors import Errors
1515
from ads.aqua.model import AquaModelApp
@@ -177,6 +177,31 @@ def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:
177177

178178
return None
179179

180+
181+
182+
@handle_exceptions
183+
def get(self,*args, **kwargs):
184+
"""
185+
Finds a list of matching models from hugging face based on query string provided from users.
186+
187+
Parameters
188+
----------
189+
query (str): The Hugging Face model name to search for.
190+
191+
Returns
192+
-------
193+
List[AquaModelSummary]
194+
Returns the matching AquaModelSummary object if found, else None.
195+
"""
196+
197+
query=self.get_argument("query",default=None)
198+
if not query:
199+
raise HTTPError(400,Errors.MISSING_REQUIRED_PARAMETER.format("query"))
200+
models=list_hf_models(query)
201+
return self.finish({"models":models})
202+
203+
204+
180205
@handle_exceptions
181206
def post(self, *args, **kwargs):
182207
"""Handles post request for the HF Models APIs

0 commit comments

Comments
 (0)