File tree Expand file tree Collapse file tree 3 files changed +35
-2
lines changed Expand file tree Collapse file tree 3 files changed +35
-2
lines changed Original file line number Diff line number Diff line change 2
2
# -*- coding: utf-8 -*-
3
3
# Copyright (c) 2024 Oracle and/or its affiliates.
4
4
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
+
6
+
7
+ def get_finetuning_config_defaults ():
8
+ """Generate and return the fine-tuning default configuration dictionary."""
9
+ return {
10
+ "shape" : {
11
+ "VM.GPU.A10.1" : {"batch_size" : 1 , "replica" : "1-10" },
12
+ "VM.GPU.A10.2" : {"batch_size" : 1 , "replica" : "1-10" },
13
+ "BM.GPU.A10.4" : {"batch_size" : 1 , "replica" : 1 },
14
+ "BM.GPU4.8" : {"batch_size" : 4 , "replica" : 1 },
15
+ "BM.GPU.A100-v2.8" : {"batch_size" : 6 , "replica" : 1 },
16
+ }
17
+ }
Original file line number Diff line number Diff line change 21
21
from ads .aqua .common .utils import (
22
22
get_container_image ,
23
23
upload_local_to_os ,
24
- get_params_dict ,
25
24
)
26
25
from ads .aqua .constants import (
27
26
DEFAULT_FT_BATCH_SIZE ,
32
31
UNKNOWN ,
33
32
UNKNOWN_DICT ,
34
33
)
34
+ from ads .aqua .config .config import get_finetuning_config_defaults
35
35
from ads .aqua .data import AquaResourceIdentifier
36
36
from ads .aqua .finetuning .constants import *
37
37
from ads .aqua .finetuning .entities import *
@@ -553,7 +553,11 @@ def get_finetuning_config(self, model_id: str) -> Dict:
553
553
A dict of allowed finetuning configs.
554
554
"""
555
555
556
- return self .get_config (model_id , AQUA_MODEL_FINETUNING_CONFIG )
556
+ config = self .get_config (model_id , AQUA_MODEL_FINETUNING_CONFIG )
557
+ if not config :
558
+ logger .info (f"Fetching default fine-tuning config for model: { model_id } " )
559
+ config = get_finetuning_config_defaults ()
560
+ return config
557
561
558
562
@telemetry (
559
563
entry_point = "plugin=finetuning&action=get_finetuning_default_params" ,
Original file line number Diff line number Diff line change 26
26
from ads .model .datascience_model import DataScienceModel
27
27
from ads .model .model_metadata import ModelCustomMetadata
28
28
from ads .aqua .common .errors import AquaValueError
29
+ from ads .aqua .config .config import get_finetuning_config_defaults
29
30
30
31
31
32
class FineTuningTestCase (TestCase ):
@@ -234,6 +235,21 @@ def test_build_oci_launch_cmd(self):
234
235
== f"--training_data { dataset_path } --output_dir { report_path } --val_set_size { val_set_size } --num_epochs { parameters .epochs } --learning_rate { parameters .learning_rate } --sample_packing { parameters .sample_packing } --micro_batch_size { parameters .batch_size } --sequence_len { parameters .sequence_len } --lora_target_modules q_proj,k_proj { finetuning_params } "
235
236
)
236
237
238
+ def test_get_finetuning_config (self ):
239
+ """Test for fetching config details for a given model to be finetuned."""
240
+
241
+ config_json = os .path .join (self .curr_dir , "test_data/finetuning/ft_config.json" )
242
+ with open (config_json , "r" ) as _file :
243
+ config = json .load (_file )
244
+
245
+ self .app .get_config = MagicMock (return_value = config )
246
+ result = self .app .get_finetuning_config (model_id = "test-model-id" )
247
+ assert result == config
248
+
249
+ self .app .get_config = MagicMock (return_value = None )
250
+ result = self .app .get_finetuning_config (model_id = "test-model-id" )
251
+ assert result == get_finetuning_config_defaults ()
252
+
237
253
def test_get_finetuning_default_params (self ):
238
254
"""Test for fetching finetuning config params for a given model."""
239
255
You can’t perform that action at this time.
0 commit comments