Skip to content

Commit f1b7e82

Browse files
committed
Removes pipleine tag validation.
1 parent dd78209 commit f1b7e82

File tree

2 files changed

+27
-27
lines changed

2 files changed

+27
-27
lines changed

ads/aqua/extension/model_handler.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from ads.aqua.extension.base_handler import AquaAPIhandler
1414
from ads.aqua.extension.errors import Errors
1515
from ads.aqua.model import AquaModelApp
16-
from ads.aqua.model.constants import ModelTask
1716
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
1817
from ads.aqua.ui import ModelFormat
1918

@@ -230,16 +229,17 @@ def post(self, *args, **kwargs):
230229
"Please verify the model's status on the Hugging Face Model Hub or select a different model."
231230
)
232231

233-
# Check pipeline_tag, it should be `text-generation`
234-
if not (
235-
hf_model_info.pipeline_tag
236-
and hf_model_info.pipeline_tag.lower() in ModelTask
237-
):
238-
raise AquaRuntimeError(
239-
f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "
240-
f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "
241-
"Please select a model with a compatible pipeline tag."
242-
)
232+
# Commented the validation below to let users to register any model type.
233+
# # Check pipeline_tag, it should be `text-generation`
234+
# if not (
235+
# hf_model_info.pipeline_tag
236+
# and hf_model_info.pipeline_tag.lower() in ModelTask
237+
# ):
238+
# raise AquaRuntimeError(
239+
# f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "
240+
# f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "
241+
# "Please select a model with a compatible pipeline tag."
242+
# )
243243

244244
# Check if it is a service/verified model
245245
aqua_model_info: AquaModelSummary = self._find_matching_aqua_model(

tests/unitary/with_extras/aqua/test_model_handler.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -264,22 +264,22 @@ def test_post_negative(self, mock_uuid, mock_format_hf_custom_error_message):
264264
)
265265
get_hf_model_info.cache_clear()
266266

267-
# case 6
268-
self.mock_handler.get_json_body = MagicMock(
269-
return_value={"model_id": "test_model_id"}
270-
)
271-
with patch.object(HfApi, "model_info") as mock_model_info:
272-
mock_model_info.return_value = MagicMock(
273-
disabled=False, id="test_model_id", pipeline_tag="not-text-generation"
274-
)
275-
self.mock_handler.post()
276-
self.mock_handler.finish.assert_called_with(
277-
'{"status": 400, "message": "Something went wrong with your request.", '
278-
'"service_payload": {}, "reason": "Unsupported pipeline tag for the chosen '
279-
"model: 'not-text-generation'. AQUA currently supports the following tasks only: "
280-
f'{", ".join(ModelTask.values())}. '
281-
'Please select a model with a compatible pipeline tag.", "request_id": "###"}'
282-
)
267+
# # case 6 pipeline Tag
268+
# self.mock_handler.get_json_body = MagicMock(
269+
# return_value={"model_id": "test_model_id"}
270+
# )
271+
# with patch.object(HfApi, "model_info") as mock_model_info:
272+
# mock_model_info.return_value = MagicMock(
273+
# disabled=False, id="test_model_id", pipeline_tag="not-text-generation"
274+
# )
275+
# self.mock_handler.post()
276+
# self.mock_handler.finish.assert_called_with(
277+
# '{"status": 400, "message": "Something went wrong with your request.", '
278+
# '"service_payload": {}, "reason": "Unsupported pipeline tag for the chosen '
279+
# "model: 'not-text-generation'. AQUA currently supports the following tasks only: "
280+
# f'{", ".join(ModelTask.values())}. '
281+
# 'Please select a model with a compatible pipeline tag.", "request_id": "###"}'
282+
# )
283283
get_hf_model_info.cache_clear()
284284

285285
@patch("uuid.uuid4")

0 commit comments

Comments
 (0)