Skip to content

Commit eab7c47

Browse files
[ODSC-65517] Support freeform and defined tags for resource creation in Aqua (#1021)
2 parents 3d8f148 + 612bf71 commit eab7c47

File tree

18 files changed

+422
-159
lines changed

18 files changed

+422
-159
lines changed

ads/aqua/app.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright (c) 2024 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
import json
56
import os
67
from dataclasses import fields
78
from typing import Dict, Union
@@ -135,6 +136,8 @@ def create_model_version_set(
135136
description: str = None,
136137
compartment_id: str = None,
137138
project_id: str = None,
139+
freeform_tags: dict = None,
140+
defined_tags: dict = None,
138141
**kwargs,
139142
) -> tuple:
140143
"""Creates ModelVersionSet from given ID or Name.
@@ -153,7 +156,10 @@ def create_model_version_set(
153156
Project OCID.
154157
tag: (str, optional)
155158
calling tag, can be Tags.AQUA_FINE_TUNING or Tags.AQUA_EVALUATION
156-
159+
freeform_tags: (dict, optional)
160+
Freeform tags for the model version set
161+
defined_tags: (dict, optional)
162+
Defined tags for the model version set
157163
Returns
158164
-------
159165
tuple: (model_version_set_id, model_version_set_name)
@@ -182,13 +188,15 @@ def create_model_version_set(
182188
mvs_freeform_tags = {
183189
tag: tag,
184190
}
191+
mvs_freeform_tags = {**mvs_freeform_tags, **(freeform_tags or {})}
185192
model_version_set = (
186193
ModelVersionSet()
187194
.with_compartment_id(compartment_id)
188195
.with_project_id(project_id)
189196
.with_name(model_version_set_name)
190197
.with_description(description)
191198
.with_freeform_tags(**mvs_freeform_tags)
199+
.with_defined_tags(**(defined_tags or {}))
192200
# TODO: decide what parameters will be needed
193201
# when refactor eval to use this method, we need to pass tag here.
194202
.create(**kwargs)
@@ -340,7 +348,9 @@ def build_cli(self) -> str:
340348
"""
341349
cmd = f"ads aqua {self._command}"
342350
params = [
343-
f"--{field.name} {getattr(self,field.name)}"
351+
f"--{field.name} {json.dumps(getattr(self, field.name))}"
352+
if isinstance(getattr(self, field.name), dict)
353+
else f"--{field.name} {getattr(self, field.name)}"
344354
for field in fields(self.__class__)
345355
if getattr(self, field.name) is not None
346356
]

ads/aqua/evaluation/entities.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ class CreateAquaEvaluationDetails(Serializable):
6464
The metrics for the evaluation.
6565
force_overwrite: (bool, optional). Defaults to `False`.
6666
Whether to force overwrite the existing file in object storage.
67+
freeform_tags: (dict, optional)
68+
Freeform tags for the evaluation model
69+
defined_tags: (dict, optional)
70+
Defined tags for the evaluation model
6771
"""
6872

6973
evaluation_source_id: str
@@ -85,6 +89,8 @@ class CreateAquaEvaluationDetails(Serializable):
8589
log_id: Optional[str] = None
8690
metrics: Optional[List[Dict[str, Any]]] = None
8791
force_overwrite: Optional[bool] = False
92+
freeform_tags: Optional[dict] = None
93+
defined_tags: Optional[dict] = None
8894

8995
class Config:
9096
extra = "ignore"

ads/aqua/evaluation/evaluation.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,10 @@ def create(
297297
evaluation_mvs_freeform_tags = {
298298
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
299299
}
300+
evaluation_mvs_freeform_tags = {
301+
**evaluation_mvs_freeform_tags,
302+
**(create_aqua_evaluation_details.freeform_tags or {}),
303+
}
300304

301305
model_version_set = (
302306
ModelVersionSet()
@@ -307,6 +311,9 @@ def create(
307311
create_aqua_evaluation_details.experiment_description
308312
)
309313
.with_freeform_tags(**evaluation_mvs_freeform_tags)
314+
.with_defined_tags(
315+
**(create_aqua_evaluation_details.defined_tags or {})
316+
)
310317
# TODO: decide what parameters will be needed
311318
.create(**kwargs)
312319
)
@@ -369,6 +376,10 @@ def create(
369376
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
370377
Tags.AQUA_EVALUATION_MODEL_ID: evaluation_model.id,
371378
}
379+
evaluation_job_freeform_tags = {
380+
**evaluation_job_freeform_tags,
381+
**(create_aqua_evaluation_details.freeform_tags or {}),
382+
}
372383

373384
evaluation_job = Job(name=evaluation_model.display_name).with_infrastructure(
374385
DataScienceJob()
@@ -379,6 +390,7 @@ def create(
379390
.with_shape_name(create_aqua_evaluation_details.shape_name)
380391
.with_block_storage_size(create_aqua_evaluation_details.block_storage_size)
381392
.with_freeform_tag(**evaluation_job_freeform_tags)
393+
.with_defined_tag(**(create_aqua_evaluation_details.defined_tags or {}))
382394
)
383395
if (
384396
create_aqua_evaluation_details.memory_in_gbs
@@ -425,6 +437,7 @@ def create(
425437
evaluation_job_run = evaluation_job.run(
426438
name=evaluation_model.display_name,
427439
freeform_tags=evaluation_job_freeform_tags,
440+
defined_tags=(create_aqua_evaluation_details.defined_tags or {}),
428441
wait=False,
429442
)
430443
logger.debug(
@@ -444,13 +457,20 @@ def create(
444457
for metadata in evaluation_model_custom_metadata.to_dict()["data"]
445458
]
446459

460+
evaluation_model_freeform_tags = {
461+
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
462+
**(create_aqua_evaluation_details.freeform_tags or {}),
463+
}
464+
evaluation_model_defined_tags = (
465+
create_aqua_evaluation_details.defined_tags or {}
466+
)
467+
447468
self.ds_client.update_model(
448469
model_id=evaluation_model.id,
449470
update_model_details=UpdateModelDetails(
450471
custom_metadata_list=updated_custom_metadata_list,
451-
freeform_tags={
452-
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
453-
},
472+
freeform_tags=evaluation_model_freeform_tags,
473+
defined_tags=evaluation_model_defined_tags,
454474
),
455475
)
456476

@@ -524,6 +544,8 @@ def create(
524544
"evaluation_job_id": evaluation_job.id,
525545
"evaluation_source": create_aqua_evaluation_details.evaluation_source_id,
526546
"evaluation_experiment_id": experiment_model_version_set_id,
547+
**evaluation_model_freeform_tags,
548+
**evaluation_model_defined_tags,
527549
},
528550
parameters=AquaEvalParams(),
529551
)

ads/aqua/extension/deployment_handler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def delete(self, model_deployment_id):
5959
return self.finish(AquaDeploymentApp().delete(model_deployment_id))
6060

6161
@handle_exceptions
62-
def put(self, *args, **kwargs):
62+
def put(self, *args, **kwargs): # noqa: ARG002
6363
"""
6464
Handles put request for the activating and deactivating OCI datascience model deployments
6565
Raises
@@ -82,7 +82,7 @@ def put(self, *args, **kwargs):
8282
raise HTTPError(400, f"The request {self.request.path} is invalid.")
8383

8484
@handle_exceptions
85-
def post(self, *args, **kwargs):
85+
def post(self, *args, **kwargs): # noqa: ARG002
8686
"""
8787
Handles post request for the deployment APIs
8888
Raises
@@ -132,6 +132,8 @@ def post(self, *args, **kwargs):
132132
private_endpoint_id = input_data.get("private_endpoint_id")
133133
container_image_uri = input_data.get("container_image_uri")
134134
cmd_var = input_data.get("cmd_var")
135+
freeform_tags = input_data.get("freeform_tags")
136+
defined_tags = input_data.get("defined_tags")
135137

136138
self.finish(
137139
AquaDeploymentApp().create(
@@ -157,6 +159,8 @@ def post(self, *args, **kwargs):
157159
private_endpoint_id=private_endpoint_id,
158160
container_image_uri=container_image_uri,
159161
cmd_var=cmd_var,
162+
freeform_tags=freeform_tags,
163+
defined_tags=defined_tags,
160164
)
161165
)
162166

@@ -196,7 +200,7 @@ def validate_predict_url(endpoint):
196200
return False
197201

198202
@handle_exceptions
199-
def post(self, *args, **kwargs):
203+
def post(self, *args, **kwargs): # noqa: ARG002
200204
"""
201205
Handles inference request for the Active Model Deployments
202206
Raises
@@ -262,7 +266,7 @@ def get(self, model_id):
262266
)
263267

264268
@handle_exceptions
265-
def post(self, *args, **kwargs):
269+
def post(self, *args, **kwargs): # noqa: ARG002
266270
"""Handles post request for the deployment param handler API.
267271
268272
Raises

ads/aqua/extension/model_handler.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def list(self):
9696
)
9797

9898
@handle_exceptions
99-
def post(self, *args, **kwargs):
99+
def post(self, *args, **kwargs): # noqa: ARG002
100100
"""
101101
Handles post request for the registering any Aqua model.
102102
Raises
@@ -131,6 +131,8 @@ def post(self, *args, **kwargs):
131131
inference_container_uri = input_data.get("inference_container_uri")
132132
allow_patterns = input_data.get("allow_patterns")
133133
ignore_patterns = input_data.get("ignore_patterns")
134+
freeform_tags = input_data.get("freeform_tags")
135+
defined_tags = input_data.get("defined_tags")
134136

135137
return self.finish(
136138
AquaModelApp().register(
@@ -145,6 +147,8 @@ def post(self, *args, **kwargs):
145147
inference_container_uri=inference_container_uri,
146148
allow_patterns=allow_patterns,
147149
ignore_patterns=ignore_patterns,
150+
freeform_tags=freeform_tags,
151+
defined_tags=defined_tags,
148152
)
149153
)
150154

@@ -170,11 +174,9 @@ def put(self, id):
170174

171175
enable_finetuning = input_data.get("enable_finetuning")
172176
task = input_data.get("task")
173-
app=AquaModelApp()
177+
app = AquaModelApp()
174178
self.finish(
175-
app.edit_registered_model(
176-
id, inference_container, enable_finetuning, task
177-
)
179+
app.edit_registered_model(id, inference_container, enable_finetuning, task)
178180
)
179181
app.clear_model_details_cache(model_id=id)
180182

@@ -218,7 +220,7 @@ def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:
218220
return None
219221

220222
@handle_exceptions
221-
def get(self, *args, **kwargs):
223+
def get(self, *args, **kwargs): # noqa: ARG002
222224
"""
223225
Finds a list of matching models from hugging face based on query string provided from users.
224226
@@ -239,7 +241,7 @@ def get(self, *args, **kwargs):
239241
return self.finish({"models": models})
240242

241243
@handle_exceptions
242-
def post(self, *args, **kwargs):
244+
def post(self, *args, **kwargs): # noqa: ARG002
243245
"""Handles post request for the HF Models APIs
244246
245247
Raises

ads/aqua/finetuning/entities.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class CreateFineTuningDetails(DataClassSerializable):
8080
The log id for fine tuning job infrastructure.
8181
force_overwrite: (bool, optional). Defaults to `False`.
8282
Whether to force overwrite the existing file in object storage.
83+
freeform_tags: (dict, optional)
84+
Freeform tags for the fine-tuning model
85+
defined_tags: (dict, optional)
86+
Defined tags for the fine-tuning model
8387
"""
8488

8589
ft_source_id: str
@@ -101,3 +105,5 @@ class CreateFineTuningDetails(DataClassSerializable):
101105
log_id: Optional[str] = None
102106
log_group_id: Optional[str] = None
103107
force_overwrite: Optional[bool] = False
108+
freeform_tags: Optional[dict] = None
109+
defined_tags: Optional[dict] = None

0 commit comments

Comments
 (0)