Skip to content

Commit a59b58a

Browse files
add default ft configs
1 parent 7ad70b3 commit a59b58a

File tree

3 files changed

+35
-2
lines changed

3 files changed

+35
-2
lines changed

ads/aqua/config/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,16 @@
22
# -*- coding: utf-8 -*-
33
# Copyright (c) 2024 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
7+
def get_finetuning_config_defaults():
8+
"""Generate and return the fine-tuning default configuration dictionary."""
9+
return {
10+
"shape": {
11+
"VM.GPU.A10.1": {"batch_size": 1, "replica": "1-10"},
12+
"VM.GPU.A10.2": {"batch_size": 1, "replica": "1-10"},
13+
"BM.GPU.A10.4": {"batch_size": 1, "replica": 1},
14+
"BM.GPU4.8": {"batch_size": 4, "replica": 1},
15+
"BM.GPU.A100-v2.8": {"batch_size": 6, "replica": 1},
16+
}
17+
}

ads/aqua/finetuning/finetuning.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from ads.aqua.common.utils import (
2222
get_container_image,
2323
upload_local_to_os,
24-
get_params_dict,
2524
)
2625
from ads.aqua.constants import (
2726
DEFAULT_FT_BATCH_SIZE,
@@ -32,6 +31,7 @@
3231
UNKNOWN,
3332
UNKNOWN_DICT,
3433
)
34+
from ads.aqua.config.config import get_finetuning_config_defaults
3535
from ads.aqua.data import AquaResourceIdentifier
3636
from ads.aqua.finetuning.constants import *
3737
from ads.aqua.finetuning.entities import *
@@ -553,7 +553,11 @@ def get_finetuning_config(self, model_id: str) -> Dict:
553553
A dict of allowed finetuning configs.
554554
"""
555555

556-
return self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG)
556+
config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG)
557+
if not config:
558+
logger.info(f"Fetching default fine-tuning config for model: {model_id}")
559+
config = get_finetuning_config_defaults()
560+
return config
557561

558562
@telemetry(
559563
entry_point="plugin=finetuning&action=get_finetuning_default_params",

tests/unitary/with_extras/aqua/test_finetuning.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ads.model.datascience_model import DataScienceModel
2727
from ads.model.model_metadata import ModelCustomMetadata
2828
from ads.aqua.common.errors import AquaValueError
29+
from ads.aqua.config.config import get_finetuning_config_defaults
2930

3031

3132
class FineTuningTestCase(TestCase):
@@ -234,6 +235,21 @@ def test_build_oci_launch_cmd(self):
234235
== f"--training_data {dataset_path} --output_dir {report_path} --val_set_size {val_set_size} --num_epochs {parameters.epochs} --learning_rate {parameters.learning_rate} --sample_packing {parameters.sample_packing} --micro_batch_size {parameters.batch_size} --sequence_len {parameters.sequence_len} --lora_target_modules q_proj,k_proj {finetuning_params}"
235236
)
236237

238+
def test_get_finetuning_config(self):
239+
"""Test for fetching config details for a given model to be finetuned."""
240+
241+
config_json = os.path.join(self.curr_dir, "test_data/finetuning/ft_config.json")
242+
with open(config_json, "r") as _file:
243+
config = json.load(_file)
244+
245+
self.app.get_config = MagicMock(return_value=config)
246+
result = self.app.get_finetuning_config(model_id="test-model-id")
247+
assert result == config
248+
249+
self.app.get_config = MagicMock(return_value=None)
250+
result = self.app.get_finetuning_config(model_id="test-model-id")
251+
assert result == get_finetuning_config_defaults()
252+
237253
def test_get_finetuning_default_params(self):
238254
"""Test for fetching finetuning config params for a given model."""
239255

0 commit comments

Comments
 (0)