Skip to content

Commit cc6aad6

Browse files
Add source model tag to imported models (#853)
2 parents 884b88f + cfdecc5 commit cc6aad6

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

ads/aqua/extension/model_handler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,10 @@ def post(self, *args, **kwargs):
179179
)
180180

181181
# Check pipeline_tag, it should be `text-generation`
182-
if hf_model_info.pipeline_tag.lower() != ModelTask.TEXT_GENERATION:
182+
if (
183+
not hf_model_info.pipeline_tag
184+
or hf_model_info.pipeline_tag.lower() != ModelTask.TEXT_GENERATION
185+
):
183186
raise AquaRuntimeError(
184187
f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "
185188
f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "

ads/aqua/model/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,10 +592,12 @@ def _create_model_catalog_entry(
592592
except Exception:
593593
logger.exception(f"Could not fetch model information for {model_name}")
594594
tags = (
595-
{**shadow_model.freeform_tags, Tags.BASE_MODEL_CUSTOM: "true"}
595+
{**shadow_model.freeform_tags, Tags.AQUA_SERVICE_MODEL_TAG: shadow_model.id}
596596
if shadow_model
597597
else {Tags.AQUA_TAG: "active", Tags.BASE_MODEL_CUSTOM: "true"}
598598
)
599+
tags.update({Tags.BASE_MODEL_CUSTOM: "true"})
600+
599601
# Remove `ready_to_import` tag that might get copied from service model.
600602
tags.pop(Tags.READY_TO_IMPORT, None)
601603
metadata = None
@@ -691,7 +693,7 @@ def _create_model_catalog_entry(
691693

692694
def register(
693695
self, import_model_details: ImportModelDetails = None, **kwargs
694-
) -> str:
696+
) -> DataScienceModel:
695697
"""Loads the model from huggingface and registers as Model in Data Science Model catalog
696698
Note: For the models that require user token, use `huggingface-cli login` to setup the token
697699
The inference container and finetuning container could be of type Service Manged Container(SMC) or custom. If it is custom, full container URI is expected. If it of type SMC, only the container family name is expected.

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ def test_create_model(self, mock_from_id, mock_validate, mock_create):
249249
@patch("ads.aqua.model.model.read_file")
250250
@patch.object(DataScienceModel, "from_id")
251251
@patch(
252-
"ads.aqua.model.model.get_artifact_path", return_value="oci://bucket@namespace/prefix"
252+
"ads.aqua.model.model.get_artifact_path",
253+
return_value="oci://bucket@namespace/prefix",
253254
)
254255
def test_get_foundation_models(
255256
self,
@@ -356,7 +357,8 @@ def test_get_foundation_models(
356357
@patch("ads.aqua.model.model.read_file")
357358
@patch.object(DataScienceModel, "from_id")
358359
@patch(
359-
"ads.aqua.model.model.get_artifact_path", return_value="oci://bucket@namespace/prefix"
360+
"ads.aqua.model.model.get_artifact_path",
361+
return_value="oci://bucket@namespace/prefix",
360362
)
361363
def test_get_model_fine_tuned(
362364
self, mock_get_artifact_path, mock_from_id, mock_read_file, mock_query_resource
@@ -563,6 +565,7 @@ def test_import_shadow_model(
563565
)
564566
ds_model.with_custom_metadata_list(custom_metadata_list)
565567
ds_model.set_spec(ds_model.CONST_MODEL_FILE_DESCRIPTION, {})
568+
ds_model.dsc_model = MagicMock(id="test_model_id")
566569
DataScienceModel.from_id = MagicMock(return_value=ds_model)
567570
reload(ads.aqua.model.model)
568571
app = AquaModelApp()
@@ -587,6 +590,7 @@ def test_import_shadow_model(
587590
) # The imported model should not have this tag
588591
assert model.freeform_tags == {
589592
"aqua_custom_base_model": "true",
593+
"aqua_service_model": "test_model_id",
590594
**ds_freeform_tags,
591595
}
592596
expected_metadata = [

0 commit comments

Comments
 (0)