Skip to content

Commit e0f0f3b

Browse files
Adding task for unverified model import
1 parent f674d4d commit e0f0f3b

File tree

4 files changed

+19
-1
lines changed

4 files changed

+19
-1
lines changed

ads/aqua/config/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,11 @@ def get_evaluation_service_config(
2929
.get(ContainerSpec.CONTAINER_SPEC, {})
3030
.get(container, {})
3131
)
32+
33+
def get_valid_tasks():
34+
return [
35+
"text_generation",
36+
"code_synthesis",
37+
"image_text_to_text",
38+
"feature_extraction",
39+
]

ads/aqua/extension/model_handler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
get_hf_model_info,
1414
list_hf_models,
1515
)
16+
from ads.aqua.config.config import get_valid_tasks
1617
from ads.aqua.extension.base_handler import AquaAPIhandler
1718
from ads.aqua.extension.errors import Errors
1819
from ads.aqua.model import AquaModelApp
@@ -119,6 +120,9 @@ def post(self, *args, **kwargs):
119120
os_path = input_data.get("os_path")
120121
if not os_path:
121122
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("os_path"))
123+
task = input_data.get("task")
124+
if task not in get_valid_tasks():
125+
raise HTTPError(400, Errors.INVALID_VALUE_OF_PARAMETER.format("task"))
122126

123127
inference_container = input_data.get("inference_container")
124128
finetuning_container = input_data.get("finetuning_container")
@@ -128,7 +132,6 @@ def post(self, *args, **kwargs):
128132
download_from_hf = (
129133
str(input_data.get("download_from_hf", "false")).lower() == "true"
130134
)
131-
132135
return self.finish(
133136
AquaModelApp().register(
134137
model=model,
@@ -139,6 +142,7 @@ def post(self, *args, **kwargs):
139142
compartment_id=compartment_id,
140143
project_id=project_id,
141144
model_file=model_file,
145+
task=task,
142146
)
143147
)
144148

@@ -164,6 +168,8 @@ def put(self, id):
164168

165169
enable_finetuning = input_data.get("enable_finetuning")
166170
task = input_data.get("task")
171+
if task not in get_valid_tasks():
172+
raise HTTPError(400, Errors.INVALID_VALUE_OF_PARAMETER.format("task"))
167173
return self.finish(
168174
AquaModelApp().edit_registered_model(
169175
id, inference_container, enable_finetuning, task

ads/aqua/model/entities.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ class ImportModelDetails(CLIBuilderMixin):
284284
local_dir: Optional[str] = None
285285
inference_container: Optional[str] = None
286286
finetuning_container: Optional[str] = None
287+
task: Optional[str] = None
287288
compartment_id: Optional[str] = None
288289
project_id: Optional[str] = None
289290
model_file: Optional[str] = None

ads/aqua/model/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,7 @@ def _create_model_catalog_entry(
735735
model_name: str,
736736
inference_container: str,
737737
finetuning_container: str,
738+
task: Optional[str],
738739
verified_model: DataScienceModel,
739740
validation_result: ModelValidationResult,
740741
compartment_id: Optional[str],
@@ -790,6 +791,7 @@ def _create_model_catalog_entry(
790791
json_dict=verified_model.model_file_description
791792
)
792793
else:
794+
tags.update({Tags.TASK: task})
793795
metadata = ModelCustomMetadata()
794796
if not inference_container:
795797
raise AquaRuntimeError(
@@ -1336,6 +1338,7 @@ def register(
13361338
model_name=model_name,
13371339
inference_container=import_model_details.inference_container,
13381340
finetuning_container=import_model_details.finetuning_container,
1341+
task=import_model_details.task,
13391342
verified_model=verified_model,
13401343
validation_result=validation_result,
13411344
compartment_id=import_model_details.compartment_id,

0 commit comments

Comments
 (0)