Skip to content

Commit 13053b1

Browse files
support input tags for finetuning
1 parent d710034 commit 13053b1

File tree

3 files changed

+71
-38
lines changed

3 files changed

+71
-38
lines changed

ads/aqua/finetuning/entities.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class CreateFineTuningDetails(DataClassSerializable):
8080
The log id for fine tuning job infrastructure.
8181
force_overwrite: (bool, optional). Defaults to `False`.
8282
Whether to force overwrite the existing file in object storage.
83+
freeform_tags: (dict, optional)
84+
Freeform tags for the fine-tuning model
85+
defined_tags: (dict, optional)
86+
Defined tags for the fine-tuning model
8387
"""
8488

8589
ft_source_id: str
@@ -101,3 +105,5 @@ class CreateFineTuningDetails(DataClassSerializable):
101105
log_id: Optional[str] = None
102106
log_group_id: Optional[str] = None
103107
force_overwrite: Optional[bool] = False
108+
freeform_tags: Optional[dict] = None
109+
defined_tags: Optional[dict] = None

ads/aqua/finetuning/finetuning.py

Lines changed: 52 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@
3535
ENV_AQUA_FINE_TUNING_CONTAINER,
3636
FineTuneCustomMetadata,
3737
)
38-
from ads.aqua.finetuning.entities import *
38+
from ads.aqua.finetuning.entities import (
39+
AquaFineTuningParams,
40+
AquaFineTuningSummary,
41+
CreateFineTuningDetails,
42+
)
3943
from ads.common.auth import default_signer
4044
from ads.common.object_storage_details import ObjectStorageDetails
4145
from ads.common.utils import get_console_link
@@ -100,14 +104,14 @@ def create(
100104
if not create_fine_tuning_details:
101105
try:
102106
create_fine_tuning_details = CreateFineTuningDetails(**kwargs)
103-
except:
107+
except Exception as ex:
104108
allowed_create_fine_tuning_details = ", ".join(
105109
field.name for field in fields(CreateFineTuningDetails)
106110
).rstrip()
107111
raise AquaValueError(
108112
"Invalid create fine tuning parameters. Allowable parameters are: "
109113
f"{allowed_create_fine_tuning_details}."
110-
)
114+
) from ex
111115

112116
source = self.get_source(create_fine_tuning_details.ft_source_id)
113117

@@ -148,28 +152,27 @@ def create(
148152
"Specify the subnet id via API or environment variable AQUA_JOB_SUBNET_ID."
149153
)
150154

151-
if create_fine_tuning_details.replica > DEFAULT_FT_REPLICA:
152-
if not (
153-
create_fine_tuning_details.log_id
154-
and create_fine_tuning_details.log_group_id
155-
):
156-
raise AquaValueError(
157-
f"Logging is required for fine tuning if replica is larger than {DEFAULT_FT_REPLICA}."
158-
)
155+
if create_fine_tuning_details.replica > DEFAULT_FT_REPLICA and not (
156+
create_fine_tuning_details.log_id
157+
and create_fine_tuning_details.log_group_id
158+
):
159+
raise AquaValueError(
160+
f"Logging is required for fine tuning if replica is larger than {DEFAULT_FT_REPLICA}."
161+
)
159162

160163
ft_parameters = None
161164
try:
162165
ft_parameters = AquaFineTuningParams(
163166
**create_fine_tuning_details.ft_parameters,
164167
)
165-
except:
168+
except Exception as ex:
166169
allowed_fine_tuning_parameters = ", ".join(
167170
field.name for field in fields(AquaFineTuningParams)
168171
).rstrip()
169172
raise AquaValueError(
170173
"Invalid fine tuning parameters. Fine tuning parameters should "
171174
f"be a dictionary with keys: {allowed_fine_tuning_parameters}."
172-
)
175+
) from ex
173176

174177
experiment_model_version_set_id = create_fine_tuning_details.experiment_id
175178
experiment_model_version_set_name = create_fine_tuning_details.experiment_name
@@ -197,11 +200,11 @@ def create(
197200
auth=default_signer(),
198201
force_overwrite=create_fine_tuning_details.force_overwrite,
199202
)
200-
except FileExistsError:
203+
except FileExistsError as fe:
201204
raise AquaFileExistsError(
202205
f"Dataset {dataset_file} already exists in {create_fine_tuning_details.report_path}. "
203206
"Please use a new dataset file name, report path or set `force_overwrite` as True."
204-
)
207+
) from fe
205208
logger.debug(
206209
f"Uploaded local file {ft_dataset_path} to object storage {dst_uri}."
207210
)
@@ -222,6 +225,8 @@ def create(
222225
description=create_fine_tuning_details.experiment_description,
223226
compartment_id=target_compartment,
224227
project_id=target_project,
228+
freeform_tags=create_fine_tuning_details.freeform_tags,
229+
defined_tags=create_fine_tuning_details.defined_tags,
225230
)
226231

227232
ft_model_custom_metadata = ModelCustomMetadata()
@@ -273,6 +278,10 @@ def create(
273278
Tags.AQUA_TAG: UNKNOWN,
274279
Tags.AQUA_FINE_TUNED_MODEL_TAG: f"{source.id}#{source.display_name}",
275280
}
281+
ft_job_freeform_tags = {
282+
**ft_job_freeform_tags,
283+
**(create_fine_tuning_details.freeform_tags or {}),
284+
}
276285

277286
ft_job = Job(name=ft_model.display_name).with_infrastructure(
278287
DataScienceJob()
@@ -286,6 +295,7 @@ def create(
286295
or DEFAULT_FT_BLOCK_STORAGE_SIZE
287296
)
288297
.with_freeform_tag(**ft_job_freeform_tags)
298+
.with_defined_tag(**(create_fine_tuning_details.defined_tags or {}))
289299
)
290300

291301
if not subnet_id:
@@ -353,6 +363,7 @@ def create(
353363
ft_job_run = ft_job.run(
354364
name=ft_model.display_name,
355365
freeform_tags=ft_job_freeform_tags,
366+
defined_tags=create_fine_tuning_details.defined_tags or {},
356367
wait=False,
357368
)
358369
logger.debug(
@@ -372,22 +383,25 @@ def create(
372383
for metadata in ft_model_custom_metadata.to_dict()["data"]
373384
]
374385

375-
source_freeform_tags = source.freeform_tags or {}
376-
source_freeform_tags.pop(Tags.LICENSE, None)
377-
source_freeform_tags.update({Tags.READY_TO_FINE_TUNE: "false"})
378-
source_freeform_tags.update({Tags.AQUA_TAG: UNKNOWN})
379-
source_freeform_tags.pop(Tags.BASE_MODEL_CUSTOM, None)
386+
model_freeform_tags = source.freeform_tags or {}
387+
model_freeform_tags.pop(Tags.LICENSE, None)
388+
model_freeform_tags.pop(Tags.BASE_MODEL_CUSTOM, None)
389+
390+
model_freeform_tags = {
391+
**model_freeform_tags,
392+
Tags.READY_TO_FINE_TUNE: "false",
393+
Tags.AQUA_TAG: UNKNOWN,
394+
Tags.AQUA_FINE_TUNED_MODEL_TAG: f"{source.id}#{source.display_name}",
395+
**(create_fine_tuning_details.freeform_tags or {}),
396+
}
397+
model_defined_tags = create_fine_tuning_details.defined_tags or {}
380398

381399
self.update_model(
382400
model_id=ft_model.id,
383401
update_model_details=UpdateModelDetails(
384402
custom_metadata_list=updated_custom_metadata_list,
385-
freeform_tags={
386-
Tags.AQUA_FINE_TUNED_MODEL_TAG: (
387-
f"{source.id}#{source.display_name}"
388-
),
389-
**source_freeform_tags,
390-
},
403+
freeform_tags=model_freeform_tags,
404+
defined_tags=model_defined_tags,
391405
),
392406
)
393407

@@ -462,12 +476,16 @@ def create(
462476
region=self.region,
463477
),
464478
),
465-
tags=dict(
466-
aqua_finetuning=Tags.AQUA_FINE_TUNING,
467-
finetuning_job_id=ft_job.id,
468-
finetuning_source=source.id,
469-
finetuning_experiment_id=experiment_model_version_set_id,
470-
),
479+
tags={
480+
**{
481+
"aqua_finetuning": Tags.AQUA_FINE_TUNING,
482+
"finetuning_job_id": ft_job.id,
483+
"finetuning_source": source.id,
484+
"finetuning_experiment_id": experiment_model_version_set_id,
485+
},
486+
**model_freeform_tags,
487+
**model_defined_tags,
488+
},
471489
parameters={
472490
key: value
473491
for key, value in asdict(ft_parameters).items()
@@ -635,6 +653,6 @@ def validate_finetuning_params(self, params: Dict = None) -> Dict:
635653
raise AquaValueError(
636654
f"Invalid fine tuning parameters. Allowable parameters are: "
637655
f"{allowed_fine_tuning_parameters}."
638-
)
656+
) from e
639657

640-
return dict(valid=True)
658+
return {"valid": True}

tests/unitary/with_extras/aqua/test_finetuning.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ def test_create_fine_tuning(
119119
self.app.ds_client.update_model = MagicMock()
120120
self.app.ds_client.update_model_provenance = MagicMock()
121121

122+
ft_model_freeform_tags = {"ftag1": "fvalue1", "ftag2": "fvalue2"}
123+
ft_model_defined_tags = {"dtag1": "dvalue1", "dtag2": "dvalue2"}
124+
122125
create_aqua_ft_details = dict(
123126
ft_source_id="ocid1.datasciencemodel.oc1.iad.<OCID>",
124127
ft_name="test_ft_name",
@@ -134,6 +137,8 @@ def test_create_fine_tuning(
134137
validation_set_size=0.2,
135138
block_storage_size=1,
136139
experiment_name="test_experiment_name",
140+
freeform_tags=ft_model_freeform_tags,
141+
defined_tags=ft_model_defined_tags,
137142
)
138143

139144
aqua_ft_summary = self.app.create(**create_aqua_ft_details)
@@ -167,10 +172,14 @@ def test_create_fine_tuning(
167172
"url": f"https://cloud.oracle.com/data-science/models/{ft_source.id}?region={self.app.region}",
168173
},
169174
"tags": {
170-
"aqua_finetuning": "aqua_finetuning",
171-
"finetuning_experiment_id": f"{mock_mvs_create.return_value[0]}",
172-
"finetuning_job_id": f"{mock_job_id.return_value}",
173-
"finetuning_source": f"{ft_source.id}",
175+
**{
176+
"aqua_finetuning": "aqua_finetuning",
177+
"finetuning_experiment_id": f"{mock_mvs_create.return_value[0]}",
178+
"finetuning_job_id": f"{mock_job_id.return_value}",
179+
"finetuning_source": f"{ft_source.id}",
180+
},
181+
**ft_model_freeform_tags,
182+
**ft_model_defined_tags,
174183
},
175184
"time_created": f"{ft_model.time_created}",
176185
}

0 commit comments

Comments
 (0)