Skip to content

Commit 7273504

Browse files
authored
Merge branch 'main' into aqua_client
2 parents 727fb75 + 21ba00b commit 7273504

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+4628
-556
lines changed

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,6 @@ jobs:
5656
$CONDA/bin/conda init
5757
source /home/runner/.bashrc
5858
pip install -r test-requirements-operators.txt
59-
pip install "oracle-automlx[forecasting]>=24.4.0"
59+
pip install "oracle-automlx[forecasting]>=24.4.1"
6060
pip install pandas>=2.2.0
6161
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast

ads/aqua/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55

66
import os
7+
from logging import getLogger
78

89
from ads import logger, set_auth
910
from ads.aqua.client.client import AsyncClient, Client
@@ -19,6 +20,7 @@ def get_logger_level():
1920
return level
2021

2122

23+
logger = getLogger(__name__)
2224
logger.setLevel(get_logger_level())
2325

2426

@@ -27,7 +29,6 @@ def set_log_level(log_level: str):
2729

2830
log_level = log_level.upper()
2931
logger.setLevel(log_level.upper())
30-
logger.handlers[0].setLevel(log_level)
3132

3233

3334
if OCI_RESOURCE_PRINCIPAL_VERSION:

ads/aqua/common/enums.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5252
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
5353
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
5454
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
55+
56+
57+
class CustomInferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5558
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
5659

5760

ads/aqua/common/utils.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
"""AQUA utils and constants."""
55

@@ -11,6 +11,7 @@
1111
import random
1212
import re
1313
import shlex
14+
import shutil
1415
import subprocess
1516
from datetime import datetime, timedelta
1617
from functools import wraps
@@ -21,6 +22,8 @@
2122
import fsspec
2223
import oci
2324
from cachetools import TTLCache, cached
25+
from huggingface_hub.constants import HF_HUB_CACHE
26+
from huggingface_hub.file_download import repo_folder_name
2427
from huggingface_hub.hf_api import HfApi, ModelInfo
2528
from huggingface_hub.utils import (
2629
GatedRepoError,
@@ -30,6 +33,7 @@
3033
)
3134
from oci.data_science.models import JobRun, Model
3235
from oci.object_storage.models import ObjectSummary
36+
from pydantic import ValidationError
3337

3438
from ads.aqua.common.enums import (
3539
InferenceContainerParamType,
@@ -788,7 +792,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
788792
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
789793

790794

791-
def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
795+
def upload_folder(
796+
os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None
797+
) -> str:
792798
"""Upload the local folder to the object storage
793799
794800
Args:
@@ -818,6 +824,48 @@ def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern
818824
return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
819825

820826

827+
def cleanup_local_hf_model_artifact(
828+
model_name: str,
829+
local_dir: str = None,
830+
):
831+
"""
832+
Helper function that deletes local artifacts downloaded from Hugging Face to free up disk space.
833+
Parameters
834+
----------
835+
model_name (str): Name of the huggingface model
836+
local_dir (str): Local directory where the object is downloaded
837+
838+
"""
839+
if local_dir and os.path.exists(local_dir):
840+
model_dir = os.path.join(local_dir, model_name)
841+
model_dir = (
842+
os.path.dirname(model_dir)
843+
if "/" in model_name or os.sep in model_name
844+
else model_dir
845+
)
846+
shutil.rmtree(model_dir, ignore_errors=True)
847+
if os.path.exists(model_dir):
848+
logger.debug(
849+
f"Could not delete local model artifact directory: {model_dir}"
850+
)
851+
else:
852+
logger.debug(f"Deleted local model artifact directory: {model_dir}.")
853+
854+
hf_local_path = os.path.join(
855+
HF_HUB_CACHE, repo_folder_name(repo_id=model_name, repo_type="model")
856+
)
857+
shutil.rmtree(hf_local_path, ignore_errors=True)
858+
859+
if os.path.exists(hf_local_path):
860+
logger.debug(
861+
f"Could not clear the local Hugging Face cache directory {hf_local_path} for the model {model_name}."
862+
)
863+
else:
864+
logger.debug(
865+
f"Cleared contents of local Hugging Face cache directory {hf_local_path} for the model {model_name}."
866+
)
867+
868+
821869
def is_service_managed_container(container):
822870
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
823871

@@ -1159,3 +1207,15 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
11591207

11601208
combined_cmd_var = cmd_var + overrides
11611209
return combined_cmd_var
1210+
1211+
1212+
def build_pydantic_error_message(ex: ValidationError):
1213+
"""Added to handle error messages from pydantic model validator.
1214+
Combine both loc and msg for errors where loc (field) is present in error details, else only build error
1215+
message using msg field."""
1216+
1217+
return {
1218+
".".join(map(str, e["loc"])): e["msg"]
1219+
for e in ex.errors()
1220+
if "loc" in e and e["loc"]
1221+
} or "; ".join(e["msg"] for e in ex.errors())

ads/aqua/data.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

6-
from dataclasses import dataclass, field
5+
from dataclasses import dataclass
76

87
from ads.common.serializer import DataClassSerializable
98

@@ -13,19 +12,3 @@ class AquaResourceIdentifier(DataClassSerializable):
1312
id: str = ""
1413
name: str = ""
1514
url: str = ""
16-
17-
18-
@dataclass(repr=False)
19-
class AquaJobSummary(DataClassSerializable):
20-
"""Represents an Aqua job summary."""
21-
22-
id: str
23-
name: str
24-
console_url: str
25-
lifecycle_state: str
26-
lifecycle_details: str
27-
time_created: str
28-
tags: dict
29-
experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
30-
source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
31-
job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)

ads/aqua/extension/finetune_handler.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55

@@ -10,9 +10,7 @@
1010
from ads.aqua.common.decorator import handle_exceptions
1111
from ads.aqua.extension.base_handler import AquaAPIhandler
1212
from ads.aqua.extension.errors import Errors
13-
from ads.aqua.extension.utils import validate_function_parameters
1413
from ads.aqua.finetuning import AquaFineTuningApp
15-
from ads.aqua.finetuning.entities import CreateFineTuningDetails
1614

1715

1816
class AquaFineTuneHandler(AquaAPIhandler):
@@ -33,7 +31,7 @@ def get(self, id=""):
3331
raise HTTPError(400, f"The request {self.request.path} is invalid.")
3432

3533
@handle_exceptions
36-
def post(self, *args, **kwargs):
34+
def post(self, *args, **kwargs): # noqa: ARG002
3735
"""Handles post request for the fine-tuning API
3836
3937
Raises
@@ -43,17 +41,13 @@ def post(self, *args, **kwargs):
4341
"""
4442
try:
4543
input_data = self.get_json_body()
46-
except Exception:
47-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
44+
except Exception as ex:
45+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
4846

4947
if not input_data:
5048
raise HTTPError(400, Errors.NO_INPUT_DATA)
5149

52-
validate_function_parameters(
53-
data_class=CreateFineTuningDetails, input_data=input_data
54-
)
55-
56-
self.finish(AquaFineTuningApp().create(CreateFineTuningDetails(**input_data)))
50+
self.finish(AquaFineTuningApp().create(**input_data))
5751

5852
def get_finetuning_config(self, model_id):
5953
"""Gets the finetuning config for Aqua model."""
@@ -71,7 +65,7 @@ def get(self, model_id):
7165
)
7266

7367
@handle_exceptions
74-
def post(self, *args, **kwargs):
68+
def post(self, *args, **kwargs): # noqa: ARG002
7569
"""Handles post request for the finetuning param handler API.
7670
7771
Raises
@@ -81,8 +75,8 @@ def post(self, *args, **kwargs):
8175
"""
8276
try:
8377
input_data = self.get_json_body()
84-
except Exception:
85-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
78+
except Exception as ex:
79+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
8680

8781
if not input_data:
8882
raise HTTPError(400, Errors.NO_INPUT_DATA)

ads/aqua/extension/model_handler.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
from typing import Optional
@@ -8,6 +8,9 @@
88
from tornado.web import HTTPError
99

1010
from ads.aqua.common.decorator import handle_exceptions
11+
from ads.aqua.common.enums import (
12+
CustomInferenceContainerTypeFamily,
13+
)
1114
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1215
from ads.aqua.common.utils import (
1316
get_hf_model_info,
@@ -128,6 +131,10 @@ def post(self, *args, **kwargs): # noqa: ARG002
128131
download_from_hf = (
129132
str(input_data.get("download_from_hf", "false")).lower() == "true"
130133
)
134+
local_dir = input_data.get("local_dir")
135+
cleanup_model_cache = (
136+
str(input_data.get("cleanup_model_cache", "true")).lower() == "true"
137+
)
131138
inference_container_uri = input_data.get("inference_container_uri")
132139
allow_patterns = input_data.get("allow_patterns")
133140
ignore_patterns = input_data.get("ignore_patterns")
@@ -139,6 +146,8 @@ def post(self, *args, **kwargs): # noqa: ARG002
139146
model=model,
140147
os_path=os_path,
141148
download_from_hf=download_from_hf,
149+
local_dir=local_dir,
150+
cleanup_model_cache=cleanup_model_cache,
142151
inference_container=inference_container,
143152
finetuning_container=finetuning_container,
144153
compartment_id=compartment_id,
@@ -163,7 +172,9 @@ def put(self, id):
163172
raise HTTPError(400, Errors.NO_INPUT_DATA)
164173

165174
inference_container = input_data.get("inference_container")
175+
inference_container_uri = input_data.get("inference_container_uri")
166176
inference_containers = AquaModelApp.list_valid_inference_containers()
177+
inference_containers.extend(CustomInferenceContainerTypeFamily.values())
167178
if (
168179
inference_container is not None
169180
and inference_container not in inference_containers
@@ -176,7 +187,13 @@ def put(self, id):
176187
task = input_data.get("task")
177188
app = AquaModelApp()
178189
self.finish(
179-
app.edit_registered_model(id, inference_container, enable_finetuning, task)
190+
app.edit_registered_model(
191+
id,
192+
inference_container,
193+
inference_container_uri,
194+
enable_finetuning,
195+
task,
196+
)
180197
)
181198
app.clear_model_details_cache(model_id=id)
182199

ads/aqua/finetuning/constants.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65
from ads.common.extended_enum import ExtendedEnumMeta
@@ -17,4 +16,8 @@ class FineTuneCustomMetadata(str, metaclass=ExtendedEnumMeta):
1716
SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container"
1817

1918

19+
class FineTuningRestrictedParams(str, metaclass=ExtendedEnumMeta):
20+
OPTIMIZER = "optimizer"
21+
22+
2023
ENV_AQUA_FINE_TUNING_CONTAINER = "AQUA_FINE_TUNING_CONTAINER"

0 commit comments

Comments
 (0)