Skip to content

Commit 2ddfe91

Browse files
update create API validation
1 parent 1f4c353 commit 2ddfe91

File tree

4 files changed

+26
-22
lines changed

4 files changed

+26
-22
lines changed

ads/aqua/common/utils.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
"""AQUA utils and constants."""
55

@@ -30,6 +30,7 @@
3030
)
3131
from oci.data_science.models import JobRun, Model
3232
from oci.object_storage.models import ObjectSummary
33+
from pydantic import ValidationError
3334

3435
from ads.aqua.common.enums import (
3536
InferenceContainerParamType,
@@ -788,7 +789,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
788789
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
789790

790791

791-
def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
792+
def upload_folder(
793+
os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None
794+
) -> str:
792795
"""Upload the local folder to the object storage
793796
794797
Args:
@@ -1159,3 +1162,15 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
11591162

11601163
combined_cmd_var = cmd_var + overrides
11611164
return combined_cmd_var
1165+
1166+
1167+
def build_pydantic_error_message(ex: ValidationError):
1168+
"""Added to handle error messages from pydantic model validator.
1169+
Combine both loc and msg for errors where loc (field) is present in error details, else only build error
1170+
message using msg field."""
1171+
1172+
return {
1173+
".".join(map(str, e["loc"])): e["msg"]
1174+
for e in ex.errors()
1175+
if "loc" in e and e["loc"]
1176+
} or "; ".join(e["msg"] for e in ex.errors())

ads/aqua/extension/finetune_handler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from ads.aqua.extension.base_handler import AquaAPIhandler
1212
from ads.aqua.extension.errors import Errors
1313
from ads.aqua.finetuning import AquaFineTuningApp
14-
from ads.aqua.finetuning.entities import CreateFineTuningDetails
1514

1615

1716
class AquaFineTuneHandler(AquaAPIhandler):
@@ -48,7 +47,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
4847
if not input_data:
4948
raise HTTPError(400, Errors.NO_INPUT_DATA)
5049

51-
self.finish(AquaFineTuningApp().create(CreateFineTuningDetails(**input_data)))
50+
self.finish(AquaFineTuningApp().create(**input_data))
5251

5352
def get_finetuning_config(self, model_id):
5453
"""Gets the finetuning config for Aqua model."""

ads/aqua/finetuning/finetuning.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ads.aqua.common.enums import Resource, Tags
1919
from ads.aqua.common.errors import AquaFileExistsError, AquaValueError
2020
from ads.aqua.common.utils import (
21+
build_pydantic_error_message,
2122
get_container_image,
2223
upload_local_to_os,
2324
)
@@ -105,15 +106,11 @@ def create(
105106
try:
106107
create_fine_tuning_details = CreateFineTuningDetails(**kwargs)
107108
except ValidationError as ex:
108-
custom_errors = {
109-
".".join(map(str, e["loc"])): e["msg"] for e in ex.errors()
110-
}
109+
custom_errors = build_pydantic_error_message(ex)
111110
raise AquaValueError(
112111
f"Invalid parameters for creating a fine-tuned model. Error details: {custom_errors}."
113112
) from ex
114113

115-
source = self.get_source(create_fine_tuning_details.ft_source_id)
116-
117114
target_compartment = (
118115
create_fine_tuning_details.compartment_id or COMPARTMENT_OCID
119116
)
@@ -211,6 +208,8 @@ def create(
211208
defined_tags=create_fine_tuning_details.defined_tags,
212209
)
213210

211+
source = self.get_source(create_fine_tuning_details.ft_source_id)
212+
214213
ft_model_custom_metadata = ModelCustomMetadata()
215214
ft_model_custom_metadata.add(
216215
key=FineTuneCustomMetadata.FINE_TUNE_SOURCE,
@@ -615,13 +614,7 @@ def _get_finetuning_params(
615614
**{**params, **{"_validate": validate}}
616615
)
617616
except ValidationError as ex:
618-
# combine both loc and msg for errors where loc (field) is present in error details, else only build error
619-
# message using msg field. Added to handle error messages from pydantic model validator handler.
620-
custom_errors = {
621-
".".join(map(str, e["loc"])): e["msg"]
622-
for e in ex.errors()
623-
if "loc" in e and e["loc"]
624-
} or "; ".join(e["msg"] for e in ex.errors())
617+
custom_errors = build_pydantic_error_message(ex)
625618
raise AquaValueError(
626619
f"Invalid finetuning parameters. Error details: {custom_errors}."
627620
) from ex

tests/unitary/with_extras/aqua/test_finetuning_handler.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*--
33

4-
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
from unittest import TestCase
88
from unittest.mock import MagicMock
99

1010
from mock import patch
11-
from notebook.base.handlers import APIHandler, IPythonHandler
11+
from notebook.base.handlers import IPythonHandler
1212

1313
from ads.aqua.extension.finetune_handler import (
1414
AquaFineTuneHandler,
1515
AquaFineTuneParamsHandler,
1616
)
1717
from ads.aqua.finetuning import AquaFineTuningApp
18-
from ads.aqua.finetuning.entities import CreateFineTuningDetails
1918

2019

2120
class TestDataset:
@@ -68,9 +67,7 @@ def test_post(self, mock_create):
6867
self.test_instance.post()
6968

7069
self.test_instance.finish.assert_called_with(mock_create.return_value)
71-
mock_create.assert_called_with(
72-
CreateFineTuningDetails(**TestDataset.mock_valid_input)
73-
)
70+
mock_create.assert_called_with(**TestDataset.mock_valid_input)
7471

7572
@patch.object(AquaFineTuningApp, "get_finetuning_config")
7673
def test_get_finetuning_config(self, mock_get_finetuning_config):

0 commit comments

Comments
 (0)