Skip to content

Commit 1f4c353

Browse files
improve validation
1 parent e2b3bce commit 1f4c353

File tree

3 files changed

+26
-15
lines changed

3 files changed

+26
-15
lines changed

ads/aqua/finetuning/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class FineTuneCustomMetadata(str, metaclass=ExtendedEnumMeta):
1616
SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container"
1717

1818

19-
class FineTuningForbiddenParams(str, metaclass=ExtendedEnumMeta):
19+
class FineTuningRestrictedParams(str, metaclass=ExtendedEnumMeta):
2020
OPTIMIZER = "optimizer"
2121

2222

ads/aqua/finetuning/entities.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,22 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
import json
6-
from typing import List, Optional
6+
from typing import List, Literal, Optional, Union
77

88
from pydantic import Field, model_validator
99

1010
from ads.aqua.common.errors import AquaValueError
1111
from ads.aqua.config.utils.serializer import Serializable
1212
from ads.aqua.data import AquaResourceIdentifier
13-
from ads.aqua.finetuning.constants import FineTuningForbiddenParams
13+
from ads.aqua.finetuning.constants import FineTuningRestrictedParams
1414

1515

1616
class AquaFineTuningParams(Serializable):
1717
"""Class for maintaining aqua fine-tuning model parameters"""
1818

1919
epochs: Optional[int] = None
2020
learning_rate: Optional[float] = None
21-
sample_packing: Optional[bool] = "auto"
21+
sample_packing: Union[bool, None, Literal["auto"]] = "auto"
2222
batch_size: Optional[int] = (
2323
None # make it batch_size for user, but internally this is micro_batch_size
2424
)
@@ -40,16 +40,18 @@ def to_dict(self) -> dict:
4040

4141
@model_validator(mode="before")
4242
@classmethod
43-
def validate_forbidden_fields(cls, data: dict):
43+
def validate_restricted_fields(cls, data: dict):
4444
# we may want to skip validation if loading data from config files instead of user entered parameters
4545
validate = data.pop("_validate", True)
4646
if not (validate and isinstance(data, dict)):
4747
return data
48-
forbidden_params = [
49-
param for param in data if param in FineTuningForbiddenParams.values()
48+
restricted_params = [
49+
param for param in data if param in FineTuningRestrictedParams.values()
5050
]
51-
if forbidden_params:
52-
raise AquaValueError(f"Found restricted parameter name: {forbidden_params}")
51+
if restricted_params:
52+
raise AquaValueError(
53+
f"Found restricted parameter name: {restricted_params}"
54+
)
5355
return data
5456

5557

ads/aqua/finetuning/finetuning.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def create(
152152
f"Logging is required for fine tuning if replica is larger than {DEFAULT_FT_REPLICA}."
153153
)
154154

155-
ft_parameters = self._validate_finetuning_params(
155+
ft_parameters = self._get_finetuning_params(
156156
create_fine_tuning_details.ft_parameters
157157
)
158158

@@ -591,8 +591,9 @@ def get_finetuning_default_params(self, model_id: str) -> Dict:
591591
default_params = {"params": {}}
592592
finetuning_config = self.get_finetuning_config(model_id)
593593
config_parameters = finetuning_config.get("configuration", UNKNOWN_DICT)
594-
config_parameters["_validate"] = False
595-
dataclass_fields = AquaFineTuningParams(**config_parameters).to_dict()
594+
dataclass_fields = self._get_finetuning_params(
595+
config_parameters, validate=False
596+
).to_dict()
596597
for name, value in config_parameters.items():
597598
if name in dataclass_fields:
598599
if name == "micro_batch_size":
@@ -602,9 +603,17 @@ def get_finetuning_default_params(self, model_id: str) -> Dict:
602603
return default_params
603604

604605
@staticmethod
605-
def _validate_finetuning_params(params: Dict = None) -> AquaFineTuningParams:
606+
def _get_finetuning_params(
607+
params: Dict = None, validate: bool = True
608+
) -> AquaFineTuningParams:
609+
"""
610+
Get and validate the fine-tuning params, raising an AquaValueError if validation fails. In order to skip
611+
@model_validator decorator's validation, pass validate=False.
612+
"""
606613
try:
607-
finetuning_params = AquaFineTuningParams(**params)
614+
finetuning_params = AquaFineTuningParams(
615+
**{**params, **{"_validate": validate}}
616+
)
608617
except ValidationError as ex:
609618
# combine both loc and msg for errors where loc (field) is present in error details, else only build error
610619
# message using msg field. Added to handle error messages from pydantic model validator handler.
@@ -631,5 +640,5 @@ def validate_finetuning_params(self, params: Dict = None) -> Dict:
631640
-------
632641
Return a list of restricted params.
633642
"""
634-
self._validate_finetuning_params(params or {})
643+
self._get_finetuning_params(params or {})
635644
return {"valid": True}

0 commit comments

Comments
 (0)