35
35
ENV_AQUA_FINE_TUNING_CONTAINER ,
36
36
FineTuneCustomMetadata ,
37
37
)
38
- from ads .aqua .finetuning .entities import *
38
+ from ads .aqua .finetuning .entities import (
39
+ AquaFineTuningParams ,
40
+ AquaFineTuningSummary ,
41
+ CreateFineTuningDetails ,
42
+ )
39
43
from ads .common .auth import default_signer
40
44
from ads .common .object_storage_details import ObjectStorageDetails
41
45
from ads .common .utils import get_console_link
@@ -100,14 +104,14 @@ def create(
100
104
if not create_fine_tuning_details :
101
105
try :
102
106
create_fine_tuning_details = CreateFineTuningDetails (** kwargs )
103
- except :
107
+ except Exception as ex :
104
108
allowed_create_fine_tuning_details = ", " .join (
105
109
field .name for field in fields (CreateFineTuningDetails )
106
110
).rstrip ()
107
111
raise AquaValueError (
108
112
"Invalid create fine tuning parameters. Allowable parameters are: "
109
113
f"{ allowed_create_fine_tuning_details } ."
110
- )
114
+ ) from ex
111
115
112
116
source = self .get_source (create_fine_tuning_details .ft_source_id )
113
117
@@ -148,28 +152,27 @@ def create(
148
152
"Specify the subnet id via API or environment variable AQUA_JOB_SUBNET_ID."
149
153
)
150
154
151
- if create_fine_tuning_details .replica > DEFAULT_FT_REPLICA :
152
- if not (
153
- create_fine_tuning_details .log_id
154
- and create_fine_tuning_details .log_group_id
155
- ):
156
- raise AquaValueError (
157
- f"Logging is required for fine tuning if replica is larger than { DEFAULT_FT_REPLICA } ."
158
- )
155
+ if create_fine_tuning_details .replica > DEFAULT_FT_REPLICA and not (
156
+ create_fine_tuning_details .log_id
157
+ and create_fine_tuning_details .log_group_id
158
+ ):
159
+ raise AquaValueError (
160
+ f"Logging is required for fine tuning if replica is larger than { DEFAULT_FT_REPLICA } ."
161
+ )
159
162
160
163
ft_parameters = None
161
164
try :
162
165
ft_parameters = AquaFineTuningParams (
163
166
** create_fine_tuning_details .ft_parameters ,
164
167
)
165
- except :
168
+ except Exception as ex :
166
169
allowed_fine_tuning_parameters = ", " .join (
167
170
field .name for field in fields (AquaFineTuningParams )
168
171
).rstrip ()
169
172
raise AquaValueError (
170
173
"Invalid fine tuning parameters. Fine tuning parameters should "
171
174
f"be a dictionary with keys: { allowed_fine_tuning_parameters } ."
172
- )
175
+ ) from ex
173
176
174
177
experiment_model_version_set_id = create_fine_tuning_details .experiment_id
175
178
experiment_model_version_set_name = create_fine_tuning_details .experiment_name
@@ -197,11 +200,11 @@ def create(
197
200
auth = default_signer (),
198
201
force_overwrite = create_fine_tuning_details .force_overwrite ,
199
202
)
200
- except FileExistsError :
203
+ except FileExistsError as fe :
201
204
raise AquaFileExistsError (
202
205
f"Dataset { dataset_file } already exists in { create_fine_tuning_details .report_path } . "
203
206
"Please use a new dataset file name, report path or set `force_overwrite` as True."
204
- )
207
+ ) from fe
205
208
logger .debug (
206
209
f"Uploaded local file { ft_dataset_path } to object storage { dst_uri } ."
207
210
)
@@ -222,6 +225,8 @@ def create(
222
225
description = create_fine_tuning_details .experiment_description ,
223
226
compartment_id = target_compartment ,
224
227
project_id = target_project ,
228
+ freeform_tags = create_fine_tuning_details .freeform_tags ,
229
+ defined_tags = create_fine_tuning_details .defined_tags ,
225
230
)
226
231
227
232
ft_model_custom_metadata = ModelCustomMetadata ()
@@ -273,6 +278,10 @@ def create(
273
278
Tags .AQUA_TAG : UNKNOWN ,
274
279
Tags .AQUA_FINE_TUNED_MODEL_TAG : f"{ source .id } #{ source .display_name } " ,
275
280
}
281
+ ft_job_freeform_tags = {
282
+ ** ft_job_freeform_tags ,
283
+ ** (create_fine_tuning_details .freeform_tags or {}),
284
+ }
276
285
277
286
ft_job = Job (name = ft_model .display_name ).with_infrastructure (
278
287
DataScienceJob ()
@@ -286,6 +295,7 @@ def create(
286
295
or DEFAULT_FT_BLOCK_STORAGE_SIZE
287
296
)
288
297
.with_freeform_tag (** ft_job_freeform_tags )
298
+ .with_defined_tag (** (create_fine_tuning_details .defined_tags or {}))
289
299
)
290
300
291
301
if not subnet_id :
@@ -353,6 +363,7 @@ def create(
353
363
ft_job_run = ft_job .run (
354
364
name = ft_model .display_name ,
355
365
freeform_tags = ft_job_freeform_tags ,
366
+ defined_tags = create_fine_tuning_details .defined_tags or {},
356
367
wait = False ,
357
368
)
358
369
logger .debug (
@@ -372,22 +383,25 @@ def create(
372
383
for metadata in ft_model_custom_metadata .to_dict ()["data" ]
373
384
]
374
385
375
- source_freeform_tags = source .freeform_tags or {}
376
- source_freeform_tags .pop (Tags .LICENSE , None )
377
- source_freeform_tags .update ({Tags .READY_TO_FINE_TUNE : "false" })
378
- source_freeform_tags .update ({Tags .AQUA_TAG : UNKNOWN })
379
- source_freeform_tags .pop (Tags .BASE_MODEL_CUSTOM , None )
386
+ model_freeform_tags = source .freeform_tags or {}
387
+ model_freeform_tags .pop (Tags .LICENSE , None )
388
+ model_freeform_tags .pop (Tags .BASE_MODEL_CUSTOM , None )
389
+
390
+ model_freeform_tags = {
391
+ ** model_freeform_tags ,
392
+ Tags .READY_TO_FINE_TUNE : "false" ,
393
+ Tags .AQUA_TAG : UNKNOWN ,
394
+ Tags .AQUA_FINE_TUNED_MODEL_TAG : f"{ source .id } #{ source .display_name } " ,
395
+ ** (create_fine_tuning_details .freeform_tags or {}),
396
+ }
397
+ model_defined_tags = create_fine_tuning_details .defined_tags or {}
380
398
381
399
self .update_model (
382
400
model_id = ft_model .id ,
383
401
update_model_details = UpdateModelDetails (
384
402
custom_metadata_list = updated_custom_metadata_list ,
385
- freeform_tags = {
386
- Tags .AQUA_FINE_TUNED_MODEL_TAG : (
387
- f"{ source .id } #{ source .display_name } "
388
- ),
389
- ** source_freeform_tags ,
390
- },
403
+ freeform_tags = model_freeform_tags ,
404
+ defined_tags = model_defined_tags ,
391
405
),
392
406
)
393
407
@@ -462,12 +476,16 @@ def create(
462
476
region = self .region ,
463
477
),
464
478
),
465
- tags = dict (
466
- aqua_finetuning = Tags .AQUA_FINE_TUNING ,
467
- finetuning_job_id = ft_job .id ,
468
- finetuning_source = source .id ,
469
- finetuning_experiment_id = experiment_model_version_set_id ,
470
- ),
479
+ tags = {
480
+ ** {
481
+ "aqua_finetuning" : Tags .AQUA_FINE_TUNING ,
482
+ "finetuning_job_id" : ft_job .id ,
483
+ "finetuning_source" : source .id ,
484
+ "finetuning_experiment_id" : experiment_model_version_set_id ,
485
+ },
486
+ ** model_freeform_tags ,
487
+ ** model_defined_tags ,
488
+ },
471
489
parameters = {
472
490
key : value
473
491
for key , value in asdict (ft_parameters ).items ()
@@ -635,6 +653,6 @@ def validate_finetuning_params(self, params: Dict = None) -> Dict:
635
653
raise AquaValueError (
636
654
f"Invalid fine tuning parameters. Allowable parameters are: "
637
655
f"{ allowed_fine_tuning_parameters } ."
638
- )
656
+ ) from e
639
657
640
- return dict ( valid = True )
658
+ return { " valid" : True }
0 commit comments