
Commit 7ad70b3

update ft validation api
1 parent 37a46a8
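
This commit reworks `validate_finetuning_params`: instead of parsing a `List[str]` of CLI-style flags and checking each key against the `AquaFineTuningParams` field names, it now accepts a `Dict` and validates by constructing `AquaFineTuningParams` directly. The unit tests are updated to pass dicts instead of flag strings.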

File tree

ads/aqua/finetuning/finetuning.py
tests/unitary/with_extras/aqua/test_finetuning.py

2 files changed: +35 −34 lines changed

ads/aqua/finetuning/finetuning.py

Lines changed: 16 additions & 15 deletions
```diff
@@ -5,7 +5,7 @@
 
 import json
 import os
-from dataclasses import asdict, fields
+from dataclasses import asdict, fields, MISSING
 from typing import Dict
 
 from oci.data_science.models import (
@@ -586,31 +586,32 @@ def get_finetuning_default_params(self, model_id: str) -> Dict:
 
         return default_params
 
-    def validate_finetuning_params(self, params: List[str] = None) -> Dict:
+    def validate_finetuning_params(self, params: Dict = None) -> Dict:
         """Validate if the fine-tuning parameters passed by the user can be overridden. Parameter values are not
         validated, only param keys are validated.
 
         Parameters
         ----------
-        params : List[str], optional
+        params : Dict, optional
             Params passed by the user.
 
         Returns
         -------
         Return a list of restricted params.
         """
-        restricted_params = []
-        if params:
-            dataclass_fields = {field.name for field in fields(AquaFineTuningParams)}
-            params_dict = get_params_dict(params)
-            for key, items in params_dict.items():
-                key = key.lstrip("--")
-                if key not in dataclass_fields:
-                    restricted_params.append(key)
-
-        if restricted_params:
+        try:
+            AquaFineTuningParams(
+                **params,
+            )
+        except Exception as e:
+            logger.debug(str(e))
+            allowed_fine_tuning_parameters = ", ".join(
+                f"{field.name} (required)" if field.default is MISSING else field.name
+                for field in fields(AquaFineTuningParams)
+            ).rstrip()
             raise AquaValueError(
-                f"Parameters {restricted_params} are set by Aqua "
-                f"and cannot be overridden or are invalid."
+                f"Invalid fine tuning parameters. Allowable parameters are: "
+                f"{allowed_fine_tuning_parameters}."
             )
+
         return dict(valid=True)
```
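
The new approach delegates key validation to the dataclass constructor: passing an unknown keyword (or omitting a required one) to `AquaFineTuningParams(**params)` raises a `TypeError`, and the handler rebuilds the allowed-parameter list from `fields()`, marking fields whose `default is MISSING` as required. A minimal, self-contained sketch of the same pattern, using a hypothetical `FineTuningParams` dataclass in place of `AquaFineTuningParams` (whose full field list is not shown in this diff):

```python
from dataclasses import dataclass, fields, MISSING
from typing import Dict, Optional


@dataclass(frozen=True)
class FineTuningParams:
    """Hypothetical stand-in for AquaFineTuningParams; the real field list differs."""

    epochs: int  # no default, so field.default is MISSING -> reported as required
    learning_rate: Optional[float] = None
    batch_size: Optional[int] = None


def validate_finetuning_params(params: Dict = None) -> Dict:
    try:
        # An unknown key (or a missing required one) makes the constructor raise
        # TypeError; params=None fails here too, since ** requires a mapping.
        FineTuningParams(**params)
    except Exception as e:
        allowed = ", ".join(
            f"{f.name} (required)" if f.default is MISSING else f.name
            for f in fields(FineTuningParams)
        )
        raise ValueError(
            f"Invalid fine tuning parameters. Allowable parameters are: {allowed}."
        ) from e
    return dict(valid=True)


print(validate_finetuning_params({"epochs": 1, "learning_rate": 2e-4}))
# {'valid': True}
# validate_finetuning_params({"lora_r": 32}) would raise:
# ValueError: Invalid fine tuning parameters. Allowable parameters are:
# epochs (required), learning_rate, batch_size.
```

Note that the broad `except Exception` also covers the `params=None` default, since unpacking `None` with `**` raises a `TypeError` as well.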

tests/unitary/with_extras/aqua/test_finetuning.py

Lines changed: 19 additions & 19 deletions
```diff
@@ -267,28 +267,28 @@ def test_get_finetuning_default_params(self):
     @parameterized.expand(
         [
             (
-                [
-                    "--batch_size 1",
-                    "--sequence_len 2048",
-                    "--sample_packing true",
-                    "--pad_to_sequence_len true",
-                    "--learning_rate 0.0002",
-                    "--lora_r 32",
-                    "--lora_alpha 16",
-                    "--lora_dropout 0.05",
-                    "--lora_target_linear true",
-                    "--lora_target_modules q_proj,k_proj",
-                ],
+                {
+                    "epochs": 1,
+                    "learning_rate": 0.0002,
+                    "batch_size": 1,
+                    "sequence_len": 2048,
+                    "sample_packing": True,
+                    "pad_to_sequence_len": True,
+                    "lora_alpha": 16,
+                    "lora_dropout": 0.05,
+                    "lora_target_linear": True,
+                    "lora_target_modules": ["q_proj", " k_proj"],
+                },
                 True,
             ),
             (
-                [
-                    "--micro_batch_size 1",
-                    "--max_sequence_len 2048",
-                    "--flash_attention true",
-                    "--pad_to_sequence_len true",
-                    "--lr_scheduler cosine",
-                ],
+                {
+                    "micro_batch_size": 1,
+                    "max_sequence_len": 2048,
+                    "flash_attention": True,
+                    "pad_to_sequence_len": True,
+                    "lr_scheduler": "cosine",
+                },
                 False,
             ),
         ]
```
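
The updated cases feed dicts straight into `validate_finetuning_params`. The test body itself is outside this hunk; a plausible shape, assuming the trailing boolean marks whether the dict is expected to validate (the second case uses keys such as `micro_batch_size` that are presumably not `AquaFineTuningParams` fields):

```python
# Hypothetical test body (not shown in this diff); `self.app` is assumed to be
# the fine-tuning app under test, and the flag to mean "expected valid".
# Requires `import pytest` and the cases from the @parameterized.expand above.
def test_validate_finetuning_params(self, params, is_valid):
    if is_valid:
        assert self.app.validate_finetuning_params(params) == dict(valid=True)
    else:
        with pytest.raises(AquaValueError):
            self.app.validate_finetuning_params(params)
```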
