Skip to content

Commit d713ada

Browse files
authored
Fixes the validation of acceptable model types for registration in AQUA. (#958)
2 parents 0d9ad9d + 8ee6005 commit d713ada

File tree

3 files changed

+38
-38
lines changed

3 files changed

+38
-38
lines changed

ads/aqua/extension/model_handler.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from ads.aqua.extension.base_handler import AquaAPIhandler
1414
from ads.aqua.extension.errors import Errors
1515
from ads.aqua.model import AquaModelApp
16-
from ads.aqua.model.constants import ModelTask
1716
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
1817
from ads.aqua.ui import ModelFormat
1918

@@ -177,10 +176,8 @@ def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:
177176

178177
return None
179178

180-
181-
182179
@handle_exceptions
183-
def get(self,*args, **kwargs):
180+
def get(self, *args, **kwargs):
184181
"""
185182
Finds a list of matching models from hugging face based on query string provided from users.
186183
@@ -194,13 +191,11 @@ def get(self,*args, **kwargs):
194191
Returns the matching model ids string
195192
"""
196193

197-
query=self.get_argument("query",default=None)
194+
query = self.get_argument("query", default=None)
198195
if not query:
199-
raise HTTPError(400,Errors.MISSING_REQUIRED_PARAMETER.format("query"))
200-
models=list_hf_models(query)
201-
return self.finish({"models":models})
202-
203-
196+
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("query"))
197+
models = list_hf_models(query)
198+
return self.finish({"models": models})
204199

205200
@handle_exceptions
206201
def post(self, *args, **kwargs):
@@ -234,16 +229,17 @@ def post(self, *args, **kwargs):
234229
"Please verify the model's status on the Hugging Face Model Hub or select a different model."
235230
)
236231

237-
# Check pipeline_tag, it should be `text-generation`
238-
if (
239-
not hf_model_info.pipeline_tag
240-
or hf_model_info.pipeline_tag.lower() != ModelTask.TEXT_GENERATION
241-
):
242-
raise AquaRuntimeError(
243-
f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "
244-
f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "
245-
"Please select a model with a compatible pipeline tag."
246-
)
232+
# Commented the validation below to let users to register any model type.
233+
# # Check pipeline_tag, it should be `text-generation`
234+
# if not (
235+
# hf_model_info.pipeline_tag
236+
# and hf_model_info.pipeline_tag.lower() in ModelTask
237+
# ):
238+
# raise AquaRuntimeError(
239+
# f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "
240+
# f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "
241+
# "Please select a model with a compatible pipeline tag."
242+
# )
247243

248244
# Check if it is a service/verified model
249245
aqua_model_info: AquaModelSummary = self._find_matching_aqua_model(

ads/aqua/model/constants.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -9,6 +8,7 @@
98
109
This module contains constants/enums used in Aqua Model.
1110
"""
11+
1212
from ads.common.extended_enum import ExtendedEnumMeta
1313

1414

@@ -21,6 +21,8 @@ class ModelCustomMetadataFields(str, metaclass=ExtendedEnumMeta):
2121

2222
class ModelTask(str, metaclass=ExtendedEnumMeta):
2323
TEXT_GENERATION = "text-generation"
24+
IMAGE_TEXT_TO_TEXT = "image-text-to-text"
25+
IMAGE_TO_TEXT = "image-to-text"
2426

2527

2628
class FineTuningMetricCategories(str, metaclass=ExtendedEnumMeta):

tests/unitary/with_extras/aqua/test_model_handler.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
from notebook.base.handlers import IPythonHandler
1313

1414
from ads.aqua.common.errors import AquaRuntimeError
15+
from ads.aqua.common.utils import get_hf_model_info
1516
from ads.aqua.extension.model_handler import (
17+
AquaHuggingFaceHandler,
1618
AquaModelHandler,
1719
AquaModelLicenseHandler,
18-
AquaHuggingFaceHandler,
1920
)
2021
from ads.aqua.model import AquaModelApp
22+
from ads.aqua.model.constants import ModelTask
2123
from ads.aqua.model.entities import AquaModel, AquaModelSummary, HFModelSummary
22-
from ads.aqua.common.utils import get_hf_model_info
2324

2425

2526
class ModelHandlerTestCase(TestCase):
@@ -263,21 +264,22 @@ def test_post_negative(self, mock_uuid, mock_format_hf_custom_error_message):
263264
)
264265
get_hf_model_info.cache_clear()
265266

266-
# case 6
267-
self.mock_handler.get_json_body = MagicMock(
268-
return_value={"model_id": "test_model_id"}
269-
)
270-
with patch.object(HfApi, "model_info") as mock_model_info:
271-
mock_model_info.return_value = MagicMock(
272-
disabled=False, id="test_model_id", pipeline_tag="not-text-generation"
273-
)
274-
self.mock_handler.post()
275-
self.mock_handler.finish.assert_called_with(
276-
'{"status": 400, "message": "Something went wrong with your request.", '
277-
'"service_payload": {}, "reason": "Unsupported pipeline tag for the chosen '
278-
"model: 'not-text-generation'. AQUA currently supports the following tasks only: "
279-
'text-generation. Please select a model with a compatible pipeline tag.", "request_id": "###"}'
280-
)
267+
# # case 6 pipeline Tag
268+
# self.mock_handler.get_json_body = MagicMock(
269+
# return_value={"model_id": "test_model_id"}
270+
# )
271+
# with patch.object(HfApi, "model_info") as mock_model_info:
272+
# mock_model_info.return_value = MagicMock(
273+
# disabled=False, id="test_model_id", pipeline_tag="not-text-generation"
274+
# )
275+
# self.mock_handler.post()
276+
# self.mock_handler.finish.assert_called_with(
277+
# '{"status": 400, "message": "Something went wrong with your request.", '
278+
# '"service_payload": {}, "reason": "Unsupported pipeline tag for the chosen '
279+
# "model: 'not-text-generation'. AQUA currently supports the following tasks only: "
280+
# f'{", ".join(ModelTask.values())}. '
281+
# 'Please select a model with a compatible pipeline tag.", "request_id": "###"}'
282+
# )
281283
get_hf_model_info.cache_clear()
282284

283285
@patch("uuid.uuid4")

0 commit comments

Comments
 (0)