
Commit 227a177

Set default FT configs for unverified models and update validation API (#863)
2 parents 1371ad9 + 7c62b0a commit 227a177

File tree

3 files changed: +71, -36 lines

ads/aqua/config/config.py

Lines changed: 14 additions & 0 deletions
@@ -2,3 +2,17 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+
+# TODO: move this to global config.json in object storage
+def get_finetuning_config_defaults():
+    """Generate and return the fine-tuning default configuration dictionary."""
+    return {
+        "shape": {
+            "VM.GPU.A10.1": {"batch_size": 1, "replica": "1-10"},
+            "VM.GPU.A10.2": {"batch_size": 1, "replica": "1-10"},
+            "BM.GPU.A10.4": {"batch_size": 1, "replica": 1},
+            "BM.GPU4.8": {"batch_size": 4, "replica": 1},
+            "BM.GPU.A100-v2.8": {"batch_size": 6, "replica": 1},
+        }
+    }
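The defaults are keyed by compute shape, so a caller can look up the batch size and replica setting for whatever shape a fine-tuning job targets. A minimal consumption sketch (the shape chosen and the fallback values are illustrative, not part of this commit):

from ads.aqua.config.config import get_finetuning_config_defaults

defaults = get_finetuning_config_defaults()

# Look up per-shape settings; "VM.GPU.A10.2" is one of the shapes defined above.
shape_config = defaults["shape"].get("VM.GPU.A10.2", {})
batch_size = shape_config.get("batch_size", 1)  # -> 1
replica = shape_config.get("replica", "1")      # -> "1-10"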

ads/aqua/finetuning/finetuning.py

Lines changed: 22 additions & 17 deletions
@@ -5,7 +5,7 @@
 
 import json
 import os
-from dataclasses import asdict, fields
+from dataclasses import asdict, fields, MISSING
 from typing import Dict
 
 from oci.data_science.models import (
@@ -21,7 +21,6 @@
 from ads.aqua.common.utils import (
     get_container_image,
     upload_local_to_os,
-    get_params_dict,
 )
 from ads.aqua.constants import (
     DEFAULT_FT_BATCH_SIZE,
@@ -32,6 +31,7 @@
     UNKNOWN,
     UNKNOWN_DICT,
 )
+from ads.aqua.config.config import get_finetuning_config_defaults
 from ads.aqua.data import AquaResourceIdentifier
 from ads.aqua.finetuning.constants import *
 from ads.aqua.finetuning.entities import *
@@ -553,7 +553,11 @@ def get_finetuning_config(self, model_id: str) -> Dict:
         A dict of allowed finetuning configs.
         """
 
-        return self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG)
+        config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG)
+        if not config:
+            logger.info(f"Fetching default fine-tuning config for model: {model_id}")
+            config = get_finetuning_config_defaults()
+        return config
 
     @telemetry(
         entry_point="plugin=finetuning&action=get_finetuning_default_params",
@@ -586,31 +590,32 @@ def get_finetuning_default_params(self, model_id: str) -> Dict:
 
         return default_params
 
-    def validate_finetuning_params(self, params: List[str] = None) -> Dict:
+    def validate_finetuning_params(self, params: Dict = None) -> Dict:
         """Validate if the fine-tuning parameters passed by the user can be overridden. Parameter values are not
         validated, only param keys are validated.
 
         Parameters
         ----------
-        params : List[str], optional
+        params : Dict, optional
            Params passed by the user.
 
         Returns
        -------
        Return a list of restricted params.
        """
-        restricted_params = []
-        if params:
-            dataclass_fields = {field.name for field in fields(AquaFineTuningParams)}
-            params_dict = get_params_dict(params)
-            for key, items in params_dict.items():
-                key = key.lstrip("--")
-                if key not in dataclass_fields:
-                    restricted_params.append(key)
-
-        if restricted_params:
+        try:
+            AquaFineTuningParams(
+                **params,
+            )
+        except Exception as e:
+            logger.debug(str(e))
+            allowed_fine_tuning_parameters = ", ".join(
+                f"{field.name} (required)" if field.default is MISSING else field.name
+                for field in fields(AquaFineTuningParams)
+            ).rstrip()
             raise AquaValueError(
-                f"Parameters {restricted_params} are set by Aqua "
-                f"and cannot be overridden or are invalid."
+                f"Invalid fine tuning parameters. Allowable parameters are: "
+                f"{allowed_fine_tuning_parameters}."
            )
+
        return dict(valid=True)
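Taken together, get_finetuning_config now falls back to the shape-based defaults when a model has no fine-tuning config attached, and validate_finetuning_params takes the user's overrides as a dict and lets AquaFineTuningParams reject unknown or missing keys. A rough usage sketch, assuming the application class is AquaFineTuningApp (the class name and the placeholder model OCID are assumptions; the params dict mirrors the valid case from the unit tests below):

from ads.aqua.finetuning.finetuning import AquaFineTuningApp  # class name assumed

app = AquaFineTuningApp()

# Keys are validated by constructing AquaFineTuningParams; an unknown key or a
# missing required key raises AquaValueError listing the allowed parameters.
result = app.validate_finetuning_params(
    params={
        "epochs": 1,
        "learning_rate": 0.0002,
        "batch_size": 1,
        "sequence_len": 2048,
        "sample_packing": True,
        "pad_to_sequence_len": True,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "lora_target_linear": True,
        "lora_target_modules": ["q_proj", "k_proj"],
    }
)
print(result)  # {'valid': True}

# For a model with no fine-tuning config in the catalog, the shape-keyed
# defaults from get_finetuning_config_defaults() are returned instead.
config = app.get_finetuning_config(model_id="<model OCID>")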

tests/unitary/with_extras/aqua/test_finetuning.py

Lines changed: 35 additions & 19 deletions
@@ -26,6 +26,7 @@
 from ads.model.datascience_model import DataScienceModel
 from ads.model.model_metadata import ModelCustomMetadata
 from ads.aqua.common.errors import AquaValueError
+from ads.aqua.config.config import get_finetuning_config_defaults
 
 
 class FineTuningTestCase(TestCase):
@@ -234,6 +235,21 @@ def test_build_oci_launch_cmd(self):
            == f"--training_data {dataset_path} --output_dir {report_path} --val_set_size {val_set_size} --num_epochs {parameters.epochs} --learning_rate {parameters.learning_rate} --sample_packing {parameters.sample_packing} --micro_batch_size {parameters.batch_size} --sequence_len {parameters.sequence_len} --lora_target_modules q_proj,k_proj {finetuning_params}"
        )
 
+    def test_get_finetuning_config(self):
+        """Test for fetching config details for a given model to be finetuned."""
+
+        config_json = os.path.join(self.curr_dir, "test_data/finetuning/ft_config.json")
+        with open(config_json, "r") as _file:
+            config = json.load(_file)
+
+        self.app.get_config = MagicMock(return_value=config)
+        result = self.app.get_finetuning_config(model_id="test-model-id")
+        assert result == config
+
+        self.app.get_config = MagicMock(return_value=None)
+        result = self.app.get_finetuning_config(model_id="test-model-id")
+        assert result == get_finetuning_config_defaults()
+
     def test_get_finetuning_default_params(self):
         """Test for fetching finetuning config params for a given model."""
 
@@ -267,28 +283,28 @@ def test_get_finetuning_default_params(self):
     @parameterized.expand(
         [
             (
-                [
-                    "--batch_size 1",
-                    "--sequence_len 2048",
-                    "--sample_packing true",
-                    "--pad_to_sequence_len true",
-                    "--learning_rate 0.0002",
-                    "--lora_r 32",
-                    "--lora_alpha 16",
-                    "--lora_dropout 0.05",
-                    "--lora_target_linear true",
-                    "--lora_target_modules q_proj,k_proj",
-                ],
+                {
+                    "epochs": 1,
+                    "learning_rate": 0.0002,
+                    "batch_size": 1,
+                    "sequence_len": 2048,
+                    "sample_packing": True,
+                    "pad_to_sequence_len": True,
+                    "lora_alpha": 16,
+                    "lora_dropout": 0.05,
+                    "lora_target_linear": True,
+                    "lora_target_modules": ["q_proj", " k_proj"],
+                },
                 True,
            ),
            (
-                [
-                    "--micro_batch_size 1",
-                    "--max_sequence_len 2048",
-                    "--flash_attention true",
-                    "--pad_to_sequence_len true",
-                    "--lr_scheduler cosine",
-                ],
+                {
+                    "micro_batch_size": 1,
+                    "max_sequence_len": 2048,
+                    "flash_attention": True,
+                    "pad_to_sequence_len": True,
+                    "lr_scheduler": "cosine",
+                },
                 False,
            ),
        ]
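Each parameterized case pairs a params dict with a flag for whether validation should pass. The body of that test is not part of this hunk; a plausible shape for it, with the method name and the use of pytest.raises assumed rather than taken from the diff, would be:

def test_validate_finetuning_params(self, params, validated):
    # Hypothetical body: valid dicts return {'valid': True}, invalid ones raise.
    if validated:
        assert self.app.validate_finetuning_params(params) == {"valid": True}
    else:
        with pytest.raises(AquaValueError):
            self.app.validate_finetuning_params(params)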
