Skip to content

Commit e9267de

Browse files
Adding AquaFineTuningConfig
1 parent 8fd603e commit e9267de

File tree

4 files changed

+51
-17
lines changed

4 files changed

+51
-17
lines changed

ads/aqua/finetuning/entities.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
import json
6-
from typing import List, Literal, Optional, Union
6+
from typing import Any, Dict, List, Literal, Optional, Union
77

88
from pydantic import Field, model_validator
99

@@ -55,6 +55,28 @@ def validate_restricted_fields(cls, data: dict):
5555
return data
5656

5757

58+
class AquaFineTuningConfig(Serializable):
    """Represents a model's supported shapes and detailed configuration for fine-tuning.

    Attributes:
        shape (Dict[str, Any], optional): Mapping of shape name (e.g., BM.GPU.A10.4)
            to that shape's fine-tuning defaults (e.g., batch_size, replica).
        finetuning_params (str, optional): Additional fine-tuning parameters string.
        configuration (Dict[str, Any], optional): Configuration details of fine-tuning.
    """

    # NOTE: despite the singular name, `shape` is a dict keyed by shape name,
    # not a list of shape names.
    shape: Optional[Dict[str, Any]] = Field(
        default_factory=dict,
        description="Mapping of supported shape names to their fine-tuning defaults.",
    )
    finetuning_params: Optional[str] = Field(
        default_factory=str, description="Fine tuning parameters."
    )
    configuration: Optional[Dict[str, Any]] = Field(
        default_factory=dict, description="Configuration details keyed by shape."
    )

    class Config:
        # Preserve unknown keys from the source config file instead of rejecting them.
        extra = "allow"
78+
79+
5880
class AquaFineTuningSummary(Serializable):
5981
"""Represents a summary of Aqua Finetuning job."""
6082

ads/aqua/finetuning/finetuning.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
FineTuneCustomMetadata,
3939
)
4040
from ads.aqua.finetuning.entities import (
41+
AquaFineTuningConfig,
4142
AquaFineTuningParams,
4243
AquaFineTuningSummary,
4344
CreateFineTuningDetails,
@@ -371,11 +372,11 @@ def create(
371372
is_custom_container = True
372373

373374
ft_parameters.batch_size = ft_parameters.batch_size or (
374-
ft_config.get("shape", UNKNOWN_DICT)
375+
(ft_config.shape if ft_config else UNKNOWN_DICT)
375376
.get(create_fine_tuning_details.shape_name, UNKNOWN_DICT)
376377
.get("batch_size", DEFAULT_FT_BATCH_SIZE)
377378
)
378-
finetuning_params = ft_config.get("finetuning_params")
379+
finetuning_params = ft_config.finetuning_params if ft_config else UNKNOWN
379380

380381
ft_job.with_runtime(
381382
self._build_fine_tuning_runtime(
@@ -626,7 +627,7 @@ def _build_oci_launch_cmd(
626627
@telemetry(
627628
entry_point="plugin=finetuning&action=get_finetuning_config", name="aqua"
628629
)
629-
def get_finetuning_config(self, model_id: str) -> Dict:
630+
def get_finetuning_config(self, model_id: str) -> AquaFineTuningConfig:
630631
"""Gets the finetuning config for given Aqua model.
631632
632633
Parameters
@@ -641,12 +642,12 @@ def get_finetuning_config(self, model_id: str) -> Dict:
641642
"""
642643
config = self.get_config_from_metadata(
643644
model_id, AquaModelMetadataKeys.FINE_TUNING_CONFIGURATION
644-
)
645+
).config
645646
if config:
646647
logger.info(
647648
f"Fetched {AquaModelMetadataKeys.FINE_TUNING_CONFIGURATION} from defined metadata for model: {model_id}."
648649
)
649-
return config
650+
return AquaFineTuningConfig(**(config or UNKNOWN_DICT))
650651
config = self.get_config(
651652
model_id,
652653
DEFINED_METADATA_TO_FILE_MAP.get(
@@ -657,7 +658,7 @@ def get_finetuning_config(self, model_id: str) -> Dict:
657658
logger.debug(
658659
f"Fine-tuning config for custom model: {model_id} is not available. Use defaults."
659660
)
660-
return config
661+
return AquaFineTuningConfig(**(config or UNKNOWN_DICT))
661662

662663
@telemetry(
663664
entry_point="plugin=finetuning&action=get_finetuning_default_params",
@@ -680,7 +681,9 @@ def get_finetuning_default_params(self, model_id: str) -> Dict:
680681
"""
681682
default_params = {"params": {}}
682683
finetuning_config = self.get_finetuning_config(model_id)
683-
config_parameters = finetuning_config.get("configuration", UNKNOWN_DICT)
684+
config_parameters = (
685+
finetuning_config.configuration if finetuning_config else UNKNOWN_DICT
686+
)
684687
dataclass_fields = self._get_finetuning_params(
685688
config_parameters, validate=False
686689
).to_dict()

tests/unitary/with_extras/aqua/test_data/finetuning/ft_config.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"sequence_len": 2048,
1111
"val_set_size": 0.1
1212
},
13+
"finetuning_params": "",
1314
"shape": {
1415
"VM.GPU.A10.2": {
1516
"batch_size": 2,

tests/unitary/with_extras/aqua/test_finetuning.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ads.aqua.common.errors import AquaValueError
2222
from ads.aqua.finetuning import AquaFineTuningApp
2323
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
24-
from ads.aqua.finetuning.entities import AquaFineTuningParams
24+
from ads.aqua.finetuning.entities import AquaFineTuningParams, AquaFineTuningConfig
2525
from ads.aqua.model.entities import AquaFineTuneModel
2626
from ads.jobs.ads_job import Job
2727
from ads.jobs.builders.infrastructure.dsc_job import DataScienceJobRun
@@ -113,11 +113,13 @@ def test_create_fine_tuning(
113113
ft_model.time_created = "test_time_created"
114114
mock_ds_model_create.return_value = ft_model
115115

116-
mock_get_finetuning_config.return_value = {
117-
"shape": {
118-
"VM.GPU.A10.1": {"batch_size": 1, "replica": 1},
116+
mock_get_finetuning_config.return_value = AquaFineTuningConfig(
117+
**{
118+
"shape": {
119+
"VM.GPU.A10.1": {"batch_size": 1, "replica": 1},
120+
}
119121
}
120-
}
122+
)
121123
mock_get_container_image.return_value = "test_container_image"
122124

123125
mock_job_id.return_value = "test_ft_job_id"
@@ -279,11 +281,13 @@ def test_get_finetuning_config(self):
279281
config_json = os.path.join(self.curr_dir, "test_data/finetuning/ft_config.json")
280282
with open(config_json, "r") as _file:
281283
config = json.load(_file)
282-
self.app.get_config_from_metadata = MagicMock(return_value={})
284+
self.app.get_config_from_metadata = MagicMock(
285+
return_value=ModelConfigResult(config=config, model_details=None)
286+
)
283287
self.app.get_config = MagicMock(
284288
return_value=ModelConfigResult(config=config, model_details=None)
285289
)
286-
result = self.app.get_finetuning_config(model_id="test-model-id")
290+
result = self.app.get_finetuning_config(model_id="test-model-id").to_dict()
287291
assert result == config
288292

289293
def test_get_finetuning_default_params(self):
@@ -306,12 +310,16 @@ def test_get_finetuning_default_params(self):
306310
with open(config_json, "r") as _file:
307311
config = json.load(_file)
308312

309-
self.app.get_finetuning_config = MagicMock(return_value=config)
313+
self.app.get_finetuning_config = MagicMock(
314+
return_value=AquaFineTuningConfig(**config)
315+
)
310316
result = self.app.get_finetuning_default_params(model_id="test_model_id")
311317
assert result == params_dict
312318

313319
# check when config json is not available
314-
self.app.get_finetuning_config = MagicMock(return_value={})
320+
self.app.get_finetuning_config = MagicMock(
321+
return_value=AquaFineTuningConfig(**{})
322+
)
315323
result = self.app.get_finetuning_default_params(model_id="test_model_id")
316324
assert result == {"params": {}}
317325

0 commit comments

Comments
 (0)