Skip to content

Commit 0dbb028

Browse files
authored
Merge branch 'main' into ahosler-patch-1
2 parents 8639a93 + ccabc04 commit 0dbb028

File tree

12 files changed

+201
-54
lines changed

12 files changed

+201
-54
lines changed

ads/aqua/app.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import json
66
import os
7+
import traceback
78
from dataclasses import fields
89
from typing import Dict, Union
910

@@ -23,7 +24,7 @@
2324
from ads.aqua.constants import UNKNOWN
2425
from ads.common import oci_client as oc
2526
from ads.common.auth import default_signer
26-
from ads.common.utils import extract_region
27+
from ads.common.utils import extract_region, is_path_exists
2728
from ads.config import (
2829
AQUA_TELEMETRY_BUCKET,
2930
AQUA_TELEMETRY_BUCKET_NS,
@@ -296,33 +297,44 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
296297
raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
297298

298299
config = {}
299-
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
300+
# if the current model has a service model tag, then load the config from the base (service) model's artifact path
301+
if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags:
302+
base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG]
303+
logger.info(
304+
f"Base model found for the model: {oci_model.id}. "
305+
f"Loading {config_file_name} for base model {base_model_ocid}."
306+
)
307+
base_model = self.ds_client.get_model(base_model_ocid).data
308+
artifact_path = get_artifact_path(base_model.custom_metadata_list)
309+
config_path = f"{os.path.dirname(artifact_path)}/config/"
310+
else:
311+
logger.info(f"Loading {config_file_name} for model {oci_model.id}...")
312+
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
313+
config_path = f"{artifact_path.rstrip('/')}/config/"
314+
300315
if not artifact_path:
301316
logger.debug(
302317
f"Failed to get artifact path from custom metadata for the model: {model_id}"
303318
)
304319
return config
305320

306-
try:
307-
config_path = f"{os.path.dirname(artifact_path)}/config/"
308-
config = load_config(
309-
config_path,
310-
config_file_name=config_file_name,
311-
)
312-
except Exception:
313-
# todo: temp fix for issue related to config load for byom models, update logic to choose the right path
321+
config_file_path = f"{config_path}{config_file_name}"
322+
if is_path_exists(config_file_path):
314323
try:
315-
config_path = f"{artifact_path.rstrip('/')}/config/"
316324
config = load_config(
317325
config_path,
318326
config_file_name=config_file_name,
319327
)
320328
except Exception:
321-
pass
329+
logger.debug(
330+
f"Error loading the {config_file_name} at path {config_path}.\n"
331+
f"{traceback.format_exc()}"
332+
)
322333

323334
if not config:
324-
logger.error(
325-
f"{config_file_name} is not available for the model: {model_id}. Check if the custom metadata has the artifact path set."
335+
logger.debug(
336+
f"{config_file_name} is not available for the model: {model_id}. "
337+
f"Check if the custom metadata has the artifact path set."
326338
)
327339
return config
328340

ads/aqua/extension/model_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
133133
)
134134
local_dir = input_data.get("local_dir")
135135
cleanup_model_cache = (
136-
str(input_data.get("cleanup_model_cache", "true")).lower() == "true"
136+
str(input_data.get("cleanup_model_cache", "false")).lower() == "true"
137137
)
138138
inference_container_uri = input_data.get("inference_container_uri")
139139
allow_patterns = input_data.get("allow_patterns")

ads/aqua/finetuning/entities.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ class CreateFineTuningDetails(Serializable):
122122
The log group id for fine tuning job infrastructure.
123123
log_id: (str, optional). Defaults to `None`.
124124
The log id for fine tuning job infrastructure.
125+
watch_logs: (bool, optional). Defaults to `False`.
126+
The flag to watch the job run logs when a fine-tuning job is created.
125127
force_overwrite: (bool, optional). Defaults to `False`.
126128
Whether to force overwrite the existing file in object storage.
127129
freeform_tags: (dict, optional)
@@ -148,6 +150,7 @@ class CreateFineTuningDetails(Serializable):
148150
subnet_id: Optional[str] = None
149151
log_id: Optional[str] = None
150152
log_group_id: Optional[str] = None
153+
watch_logs: Optional[bool] = False
151154
force_overwrite: Optional[bool] = False
152155
freeform_tags: Optional[dict] = None
153156
defined_tags: Optional[dict] = None

ads/aqua/finetuning/finetuning.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import json
66
import os
7+
import time
8+
import traceback
79
from typing import Dict
810

911
from oci.data_science.models import (
@@ -149,6 +151,15 @@ def create(
149151
f"Logging is required for fine tuning if replica is larger than {DEFAULT_FT_REPLICA}."
150152
)
151153

154+
if create_fine_tuning_details.watch_logs and not (
155+
create_fine_tuning_details.log_id
156+
and create_fine_tuning_details.log_group_id
157+
):
158+
raise AquaValueError(
159+
"Logging is required for fine tuning if watch_logs is set to True. "
160+
"Please provide log_id and log_group_id with the request parameters."
161+
)
162+
152163
ft_parameters = self._get_finetuning_params(
153164
create_fine_tuning_details.ft_parameters
154165
)
@@ -422,6 +433,20 @@ def create(
422433
value=source.display_name,
423434
)
424435

436+
if create_fine_tuning_details.watch_logs:
437+
logger.info(
438+
f"Watching fine-tuning job run logs for {ft_job_run.id}. Press Ctrl+C to stop watching logs.\n"
439+
)
440+
try:
441+
ft_job_run.watch()
442+
except KeyboardInterrupt:
443+
logger.info(f"\nStopped watching logs for {ft_job_run.id}.\n")
444+
time.sleep(1)
445+
except Exception:
446+
logger.debug(
447+
f"Something unexpected occurred while watching logs.\n{traceback.format_exc()}"
448+
)
449+
425450
return AquaFineTuningSummary(
426451
id=ft_model.id,
427452
name=ft_model.display_name,

ads/aqua/model/entities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ class ImportModelDetails(CLIBuilderMixin):
283283
os_path: str
284284
download_from_hf: Optional[bool] = True
285285
local_dir: Optional[str] = None
286-
cleanup_model_cache: Optional[bool] = True
286+
cleanup_model_cache: Optional[bool] = False
287287
inference_container: Optional[str] = None
288288
finetuning_container: Optional[str] = None
289289
compartment_id: Optional[str] = None

ads/aqua/model/model.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
LifecycleStatus,
3030
_build_resource_identifier,
3131
cleanup_local_hf_model_artifact,
32-
copy_model_config,
3332
create_word_icon,
3433
generate_tei_cmd_var,
3534
get_artifact_path,
@@ -969,24 +968,6 @@ def _create_model_catalog_entry(
969968
)
970969
tags[Tags.LICENSE] = validation_result.tags.get(Tags.LICENSE, UNKNOWN)
971970

972-
try:
973-
# If verified model already has a artifact json, use that.
974-
artifact_path = metadata.get(MODEL_BY_REFERENCE_OSS_PATH_KEY).value
975-
logger.info(
976-
f"Found model artifact in the service bucket. "
977-
f"Using artifact from service bucket instead of {os_path}."
978-
)
979-
980-
# todo: implement generic copy_folder method
981-
# copy model config from artifact path to user bucket
982-
copy_model_config(
983-
artifact_path=artifact_path, os_path=os_path, auth=default_signer()
984-
)
985-
except Exception:
986-
logger.debug(
987-
f"Proceeding with model registration without copying model config files at {os_path}. "
988-
f"Default configuration will be used for deployment and fine-tuning."
989-
)
990971
# Set artifact location to user bucket, and replace existing key if present.
991972
metadata.add(
992973
key=MODEL_BY_REFERENCE_OSS_PATH_KEY,

ads/cli.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
#!/usr/bin/env python
2-
32
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

5+
import json
66
import logging
77
import sys
88
import traceback
99
import uuid
1010

1111
import fire
12+
from pydantic import BaseModel
1213

1314
from ads.common import logger
1415

@@ -84,7 +85,13 @@ def serialize(data):
8485
The string representation of each dataclass object.
8586
"""
8687
if isinstance(data, list):
87-
[print(str(item)) for item in data]
88+
for item in data:
89+
if isinstance(item, BaseModel):
90+
print(json.dumps(item.dict(), indent=4))
91+
else:
92+
print(str(item))
93+
elif isinstance(data, BaseModel):
94+
print(json.dumps(data.dict(), indent=4))
8895
else:
8996
print(str(data))
9097

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
InvalidParameterError,
1616
)
1717
from ads.opctl.operator.lowcode.common.utils import merge_category_columns
18+
from ads.opctl.operator.lowcode.forecast.operator_config import ForecastOperatorSpec
1819

1920

2021
class Transformations(ABC):
@@ -34,6 +35,7 @@ def __init__(self, dataset_info, name="historical_data"):
3435
self.dataset_info = dataset_info
3536
self.target_category_columns = dataset_info.target_category_columns
3637
self.target_column_name = dataset_info.target_column
38+
self.raw_column_names = None
3739
self.dt_column_name = (
3840
dataset_info.datetime_column.name if dataset_info.datetime_column else None
3941
)
@@ -60,7 +62,8 @@ def run(self, data):
6062
6163
"""
6264
clean_df = self._remove_trailing_whitespace(data)
63-
# clean_df = self._normalize_column_names(clean_df)
65+
if isinstance(self.dataset_info, ForecastOperatorSpec):
66+
clean_df = self._clean_column_names(clean_df)
6467
if self.name == "historical_data":
6568
self._check_historical_dataset(clean_df)
6669
clean_df = self._set_series_id_column(clean_df)
@@ -98,8 +101,36 @@ def run(self, data):
98101
def _remove_trailing_whitespace(self, df):
99102
return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
100103

101-
# def _normalize_column_names(self, df):
102-
# return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
104+
def _clean_column_names(self, df):
105+
"""
106+
Remove all whitespaces from column names in a DataFrame and store the original names.
107+
108+
Parameters:
109+
df (pd.DataFrame): The DataFrame whose column names need to be cleaned.
110+
111+
Returns:
112+
pd.DataFrame: The DataFrame with cleaned column names.
113+
"""
114+
115+
self.raw_column_names = {
116+
col: col.replace(" ", "") for col in df.columns if " " in col
117+
}
118+
df.columns = [self.raw_column_names.get(col, col) for col in df.columns]
119+
120+
if self.target_column_name:
121+
self.target_column_name = self.raw_column_names.get(
122+
self.target_column_name, self.target_column_name
123+
)
124+
self.dt_column_name = self.raw_column_names.get(
125+
self.dt_column_name, self.dt_column_name
126+
)
127+
128+
if self.target_category_columns:
129+
self.target_category_columns = [
130+
self.raw_column_names.get(col, col)
131+
for col in self.target_category_columns
132+
]
133+
return df
103134

104135
def _set_series_id_column(self, df):
105136
self._target_category_columns_map = {}
@@ -233,6 +264,10 @@ def _check_historical_dataset(self, df):
233264
expected_names = [self.target_column_name, self.dt_column_name] + (
234265
self.target_category_columns if self.target_category_columns else []
235266
)
267+
268+
if self.raw_column_names:
269+
expected_names.extend(list(self.raw_column_names.values()))
270+
236271
if set(df.columns) != set(expected_names):
237272
raise DataMismatchError(
238273
f"Expected {self.name} to have columns: {expected_names}, but instead found column names: {df.columns}. Is the {self.name} path correct?"

tests/unitary/with_extras/aqua/test_config.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
import json
66
import os
7-
from unittest.mock import patch
7+
import pytest
8+
from unittest.mock import patch, MagicMock
9+
10+
import oci.data_science.models
811

912
from ads.aqua.common.entities import ContainerSpec
1013
from ads.aqua.config.config import get_evaluation_service_config
14+
from ads.aqua.app import AquaApp
1115

1216

1317
class TestConfig:
@@ -37,3 +41,63 @@ def test_evaluation_service_config(self, mock_get_container_config):
3741
test_result.to_dict()
3842
== expected_result[ContainerSpec.CONTAINER_SPEC]["test_container"]
3943
)
44+
45+
@pytest.mark.parametrize(
46+
"custom_metadata",
47+
[
48+
{
49+
"category": "Other",
50+
"description": "test_desc",
51+
"key": "artifact_location",
52+
"value": "artifact_location",
53+
},
54+
{},
55+
],
56+
)
57+
@pytest.mark.parametrize("verified_model", [True, False])
58+
@pytest.mark.parametrize("path_exists", [True, False])
59+
@patch("ads.aqua.app.load_config")
60+
def test_load_config(
61+
self, mock_load_config, custom_metadata, verified_model, path_exists
62+
):
63+
mock_load_config.return_value = {"config_key": "config_value"}
64+
service_model_tag = (
65+
{"aqua_service_model": "aqua_service_model_id"} if verified_model else {}
66+
)
67+
68+
self.app = AquaApp()
69+
70+
model = {
71+
"id": "mock_id",
72+
"lifecycle_details": "mock_lifecycle_details",
73+
"lifecycle_state": "mock_lifecycle_state",
74+
"project_id": "mock_project_id",
75+
"freeform_tags": {
76+
**{
77+
"OCI_AQUA": "",
78+
},
79+
**service_model_tag,
80+
},
81+
"custom_metadata_list": [
82+
oci.data_science.models.Metadata(**custom_metadata)
83+
],
84+
}
85+
86+
self.app.ds_client.get_model = MagicMock(
87+
return_value=oci.response.Response(
88+
status=200,
89+
request=MagicMock(),
90+
headers=MagicMock(),
91+
data=oci.data_science.models.Model(**model),
92+
)
93+
)
94+
with patch("ads.aqua.app.is_path_exists", return_value=path_exists):
95+
result = self.app.get_config(
96+
model_id="test_model_id", config_file_name="test_config_file_name"
97+
)
98+
if not path_exists:
99+
assert result == {}
100+
if not custom_metadata:
101+
assert result == {}
102+
if path_exists and custom_metadata:
103+
assert result == {"config_key": "config_value"}

0 commit comments

Comments
 (0)