Skip to content

Commit b1b4b6d

Browse files
committed
Fixes the validation of acceptable model types for registration in AQUA.
1 parent 38d969c commit b1b4b6d

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

ads/aqua/extension/model_handler.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,8 @@ def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:
177177

178178
return None
179179

180-
181-
182180
@handle_exceptions
183-
def get(self,*args, **kwargs):
181+
def get(self, *args, **kwargs):
184182
"""
185183
Finds a list of matching models from hugging face based on query string provided from users.
186184
@@ -194,13 +192,11 @@ def get(self,*args, **kwargs):
194192
Returns the matching model ids string
195193
"""
196194

197-
query=self.get_argument("query",default=None)
195+
query = self.get_argument("query", default=None)
198196
if not query:
199-
raise HTTPError(400,Errors.MISSING_REQUIRED_PARAMETER.format("query"))
200-
models=list_hf_models(query)
201-
return self.finish({"models":models})
202-
203-
197+
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("query"))
198+
models = list_hf_models(query)
199+
return self.finish({"models": models})
204200

205201
@handle_exceptions
206202
def post(self, *args, **kwargs):
@@ -235,9 +231,9 @@ def post(self, *args, **kwargs):
235231
)
236232

237233
# Check pipeline_tag, it should be `text-generation`
238-
if (
239-
not hf_model_info.pipeline_tag
240-
or hf_model_info.pipeline_tag.lower() != ModelTask.TEXT_GENERATION
234+
if not (
235+
hf_model_info.pipeline_tag
236+
and hf_model_info.pipeline_tag.lower() in ModelTask
241237
):
242238
raise AquaRuntimeError(
243239
f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "

ads/aqua/model/constants.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -9,6 +8,7 @@
98
109
This module contains constants/enums used in Aqua Model.
1110
"""
11+
1212
from ads.common.extended_enum import ExtendedEnumMeta
1313

1414

@@ -21,6 +21,8 @@ class ModelCustomMetadataFields(str, metaclass=ExtendedEnumMeta):
2121

2222
class ModelTask(str, metaclass=ExtendedEnumMeta):
2323
TEXT_GENERATION = "text-generation"
24+
IMAGE_TEXT_TO_TEXT = "image-text-to-text"
25+
IMAGE_TO_TEXT = "image-to-text"
2426

2527

2628
class FineTuningMetricCategories(str, metaclass=ExtendedEnumMeta):

0 commit comments

Comments
 (0)