Skip to content

Commit 1583829

Browse files
support input tags for model creation
1 parent 3d8f148 commit 1583829

File tree

4 files changed

+117
-10
lines changed

4 files changed

+117
-10
lines changed

ads/aqua/app.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright (c) 2024 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
import json
56
import os
67
from dataclasses import fields
78
from typing import Dict, Union
@@ -135,6 +136,8 @@ def create_model_version_set(
135136
description: str = None,
136137
compartment_id: str = None,
137138
project_id: str = None,
139+
freeform_tags: dict = None,
140+
defined_tags: dict = None,
138141
**kwargs,
139142
) -> tuple:
140143
"""Creates ModelVersionSet from given ID or Name.
@@ -153,7 +156,10 @@ def create_model_version_set(
153156
Project OCID.
154157
tag: (str, optional)
155158
calling tag, can be Tags.AQUA_FINE_TUNING or Tags.AQUA_EVALUATION
156-
159+
freeform_tags: (dict, optional)
160+
Freeform tags for the model version set
161+
defined_tags: (dict, optional)
162+
Defined tags for the model version set
157163
Returns
158164
-------
159165
tuple: (model_version_set_id, model_version_set_name)
@@ -182,13 +188,15 @@ def create_model_version_set(
182188
mvs_freeform_tags = {
183189
tag: tag,
184190
}
191+
mvs_freeform_tags = {**mvs_freeform_tags, **(freeform_tags or {})}
185192
model_version_set = (
186193
ModelVersionSet()
187194
.with_compartment_id(compartment_id)
188195
.with_project_id(project_id)
189196
.with_name(model_version_set_name)
190197
.with_description(description)
191198
.with_freeform_tags(**mvs_freeform_tags)
199+
.with_defined_tags(**(defined_tags or {}))
192200
# TODO: decide what parameters will be needed
193201
# when refactor eval to use this method, we need to pass tag here.
194202
.create(**kwargs)
@@ -340,7 +348,9 @@ def build_cli(self) -> str:
340348
"""
341349
cmd = f"ads aqua {self._command}"
342350
params = [
343-
f"--{field.name} {getattr(self,field.name)}"
351+
f"--{field.name} {json.dumps(getattr(self, field.name))}"
352+
if isinstance(getattr(self, field.name), dict)
353+
else f"--{field.name} {getattr(self, field.name)}"
344354
for field in fields(self.__class__)
345355
if getattr(self, field.name) is not None
346356
]

ads/aqua/model/entities.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,8 @@ class ImportModelDetails(CLIBuilderMixin):
291291
inference_container_uri: Optional[str] = None
292292
allow_patterns: Optional[List[str]] = None
293293
ignore_patterns: Optional[List[str]] = None
294+
freeform_tags: Optional[dict] = None
295+
defined_tags: Optional[dict] = None
294296

295297
def __post_init__(self):
296298
self._command = "model register"

ads/aqua/model/model.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,13 @@ class AquaModelApp(AquaApp):
127127

128128
@telemetry(entry_point="plugin=model&action=create", name="aqua")
129129
def create(
130-
self, model_id: str, project_id: str, compartment_id: str = None, **kwargs
130+
self,
131+
model_id: str,
132+
project_id: str,
133+
compartment_id: str = None,
134+
freeform_tags: Optional[dict] = None,
135+
defined_tags: Optional[dict] = None,
136+
**kwargs,
131137
) -> DataScienceModel:
132138
"""Creates custom aqua model from service model.
133139
@@ -140,7 +146,10 @@ def create(
140146
compartment_id: str
141147
The compartment id for custom model. Defaults to None.
142148
If not provided, compartment id will be fetched from environment variables.
143-
149+
freeform_tags: dict
150+
Freeform tags for the model
151+
defined_tags: dict
152+
Defined tags for the model
144153
Returns
145154
-------
146155
DataScienceModel:
@@ -157,15 +166,25 @@ def create(
157166
)
158167
return service_model
159168

169+
# combine tags
170+
combined_freeform_tags = {
171+
**(service_model.freeform_tags or {}),
172+
**(freeform_tags or {}),
173+
}
174+
combined_defined_tags = {
175+
**(service_model.defined_tags or {}),
176+
**(defined_tags or {}),
177+
}
178+
160179
custom_model = (
161180
DataScienceModel()
162181
.with_compartment_id(target_compartment)
163182
.with_project_id(target_project)
164183
.with_model_file_description(json_dict=service_model.model_file_description)
165184
.with_display_name(service_model.display_name)
166185
.with_description(service_model.description)
167-
.with_freeform_tags(**(service_model.freeform_tags or {}))
168-
.with_defined_tags(**(service_model.defined_tags or {}))
186+
.with_freeform_tags(**combined_freeform_tags)
187+
.with_defined_tags(**combined_defined_tags)
169188
.with_custom_metadata_list(service_model.custom_metadata_list)
170189
.with_defined_metadata_list(service_model.defined_metadata_list)
171190
.with_provenance_metadata(service_model.provenance_metadata)
@@ -414,7 +433,7 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
414433
except Exception as ex:
415434
raise AquaRuntimeError(
416435
f"The given model already doesn't support finetuning: {ex}"
417-
)
436+
) from ex
418437

419438
custom_metadata_list.remove("modelDescription")
420439
if task:
@@ -766,6 +785,8 @@ def _create_model_catalog_entry(
766785
compartment_id: Optional[str],
767786
project_id: Optional[str],
768787
inference_container_uri: Optional[str],
788+
freeform_tags: Optional[dict] = None,
789+
defined_tags: Optional[dict] = None,
769790
) -> DataScienceModel:
770791
"""Create model by reference from the object storage path
771792
@@ -778,6 +799,8 @@ def _create_model_catalog_entry(
778799
compartment_id (Optional[str]): Compartment Id of the compartment where the model has to be created
779800
project_id (Optional[str]): Project id of the project where the model has to be created
780801
inference_container_uri (Optional[str]): Inference container uri for BYOC
802+
freeform_tags (dict): Freeform tags for the model
803+
defined_tags (dict): Defined tags for the model
781804
782805
Returns:
783806
DataScienceModel: Returns Datascience model instance.
@@ -918,13 +941,16 @@ def _create_model_catalog_entry(
918941
category="Other",
919942
replace=True,
920943
)
944+
# override tags with freeform tags if set
945+
tags = {**tags, **(freeform_tags or {})}
921946
model = (
922947
model.with_custom_metadata_list(metadata)
923948
.with_compartment_id(compartment_id or COMPARTMENT_OCID)
924949
.with_project_id(project_id or PROJECT_OCID)
925950
.with_artifact(os_path)
926951
.with_display_name(model_name)
927952
.with_freeform_tags(**tags)
953+
.with_defined_tags(**(defined_tags or {}))
928954
).create(model_by_reference=True)
929955
logger.debug(model)
930956
return model
@@ -1314,7 +1340,7 @@ def _download_model_from_hf(
13141340
os_path=os_path,
13151341
local_dir=local_dir,
13161342
model_name=model_name,
1317-
exclude_pattern=f"{HF_METADATA_FOLDER}*"
1343+
exclude_pattern=f"{HF_METADATA_FOLDER}*",
13181344
)
13191345

13201346
return model_artifact_path
@@ -1402,6 +1428,8 @@ def register(
14021428
compartment_id=import_model_details.compartment_id,
14031429
project_id=import_model_details.project_id,
14041430
inference_container_uri=import_model_details.inference_container_uri,
1431+
freeform_tags=import_model_details.freeform_tags,
1432+
defined_tags=import_model_details.defined_tags,
14051433
)
14061434
# registered model will always have inference and evaluation container, but
14071435
# fine-tuning container may be not set

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -748,13 +748,13 @@ def test_import_verified_model(
748748
local_dir=str(tmpdir),
749749
download_from_hf=True,
750750
allow_patterns=["*.json"],
751-
ignore_patterns=["test.json"]
751+
ignore_patterns=["test.json"],
752752
)
753753
mock_snapshot_download.assert_called_with(
754754
repo_id=model_name,
755755
local_dir=f"{str(tmpdir)}/{model_name}",
756756
allow_patterns=["*.json"],
757-
ignore_patterns=["test.json"]
757+
ignore_patterns=["test.json"],
758758
)
759759
mock_subprocess.assert_called_with(
760760
shlex.split(
@@ -1119,6 +1119,61 @@ def test_import_tei_model_byoc(
11191119
assert model.ready_to_deploy is True
11201120
assert model.ready_to_finetune is False
11211121

1122+
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
1123+
@patch("ads.model.datascience_model.DataScienceModel.sync")
1124+
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
1125+
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
1126+
@patch.object(HfApi, "model_info")
1127+
@patch("ads.aqua.common.utils.load_config", return_value={})
1128+
def test_import_model_with_input_tags(
1129+
self,
1130+
mock_load_config,
1131+
mock_list_objects,
1132+
mock_upload_artifact,
1133+
mock_sync,
1134+
mock_ocidsc_create,
1135+
mock_get_hf_model_info,
1136+
mock_init_client,
1137+
):
1138+
my_model = "oracle/aqua-1t-mega-model"
1139+
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
1140+
1141+
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
1142+
ds_freeform_tags = {
1143+
"OCI_AQUA": "active",
1144+
}
1145+
mock_list_objects.return_value = MagicMock(objects=[])
1146+
1147+
reload(ads.aqua.model.model)
1148+
app = AquaModelApp()
1149+
with patch.object(AquaModelApp, "list") as aqua_model_mock_list:
1150+
aqua_model_mock_list.return_value = [
1151+
AquaModelSummary(
1152+
id="test_id1",
1153+
name="organization1/name1",
1154+
organization="organization1",
1155+
)
1156+
]
1157+
model: AquaModel = app.register(
1158+
model=my_model,
1159+
os_path=os_path,
1160+
inference_container="odsc-vllm-or-tgi-container",
1161+
finetuning_container="odsc-llm-fine-tuning",
1162+
download_from_hf=False,
1163+
freeform_tags={"ftag1": "fvalue1", "ftag2": "fvalue2"},
1164+
defined_tags={"dtag1": "dvalue1", "dtag2": "dvalue2"},
1165+
)
1166+
assert model.tags == {
1167+
"aqua_custom_base_model": "true",
1168+
"model_format": "SAFETENSORS",
1169+
"ready_to_fine_tune": "true",
1170+
"dtag1": "dvalue1",
1171+
"dtag2": "dvalue2",
1172+
"ftag1": "fvalue1",
1173+
"ftag2": "fvalue2",
1174+
**ds_freeform_tags,
1175+
}
1176+
11221177
@pytest.mark.parametrize(
11231178
"data, expected_output",
11241179
[
@@ -1163,6 +1218,18 @@ def test_import_tei_model_byoc(
11631218
},
11641219
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-tei-serving --inference_container_uri <region>.ocir.io/<your_tenancy>/<your_image>",
11651220
),
1221+
(
1222+
{
1223+
"os_path": "oci://aqua-bkt@aqua-ns/path",
1224+
"model": "oracle/oracle-1it",
1225+
"inference_container": "odsc-vllm-serving",
1226+
"freeform_tags": {"ftag1": "fvalue1", "ftag2": "fvalue2"},
1227+
"defined_tags": {"dtag1": "dvalue1", "dtag2": "dvalue2"},
1228+
},
1229+
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path "
1230+
"--download_from_hf True --inference_container odsc-vllm-serving --freeform_tags "
1231+
'{"ftag1": "fvalue1", "ftag2": "fvalue2"} --defined_tags {"dtag1": "dvalue1", "dtag2": "dvalue2"}',
1232+
),
11661233
],
11671234
)
11681235
def test_import_cli(self, data, expected_output):

0 commit comments

Comments
 (0)