Skip to content

Commit 2cf5b51

Browse files
[ODSC-66853] Load base model config by default (#1053)
2 parents 0cc2b58 + 33add8e commit 2cf5b51

File tree

4 files changed

+92
-39
lines changed

4 files changed

+92
-39
lines changed

ads/aqua/app.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import json
66
import os
7+
import traceback
78
from dataclasses import fields
89
from typing import Dict, Union
910

@@ -23,7 +24,7 @@
2324
from ads.aqua.constants import UNKNOWN
2425
from ads.common import oci_client as oc
2526
from ads.common.auth import default_signer
26-
from ads.common.utils import extract_region
27+
from ads.common.utils import extract_region, is_path_exists
2728
from ads.config import (
2829
AQUA_TELEMETRY_BUCKET,
2930
AQUA_TELEMETRY_BUCKET_NS,
@@ -296,33 +297,44 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
296297
raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
297298

298299
config = {}
299-
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
300+
# if the current model has a service model tag, then
301+
if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags:
302+
base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG]
303+
logger.info(
304+
f"Base model found for the model: {oci_model.id}. "
305+
f"Loading {config_file_name} for base model {base_model_ocid}."
306+
)
307+
base_model = self.ds_client.get_model(base_model_ocid).data
308+
artifact_path = get_artifact_path(base_model.custom_metadata_list)
309+
config_path = f"{os.path.dirname(artifact_path)}/config/"
310+
else:
311+
logger.info(f"Loading {config_file_name} for model {oci_model.id}...")
312+
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
313+
config_path = f"{artifact_path.rstrip('/')}/config/"
314+
300315
if not artifact_path:
301316
logger.debug(
302317
f"Failed to get artifact path from custom metadata for the model: {model_id}"
303318
)
304319
return config
305320

306-
try:
307-
config_path = f"{os.path.dirname(artifact_path)}/config/"
308-
config = load_config(
309-
config_path,
310-
config_file_name=config_file_name,
311-
)
312-
except Exception:
313-
# todo: temp fix for issue related to config load for byom models, update logic to choose the right path
321+
config_file_path = f"{config_path}{config_file_name}"
322+
if is_path_exists(config_file_path):
314323
try:
315-
config_path = f"{artifact_path.rstrip('/')}/config/"
316324
config = load_config(
317325
config_path,
318326
config_file_name=config_file_name,
319327
)
320328
except Exception:
321-
pass
329+
logger.debug(
330+
f"Error loading the {config_file_name} at path {config_path}.\n"
331+
f"{traceback.format_exc()}"
332+
)
322333

323334
if not config:
324-
logger.error(
325-
f"{config_file_name} is not available for the model: {model_id}. Check if the custom metadata has the artifact path set."
335+
logger.debug(
336+
f"{config_file_name} is not available for the model: {model_id}. "
337+
f"Check if the custom metadata has the artifact path set."
326338
)
327339
return config
328340

ads/aqua/model/model.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
LifecycleStatus,
3030
_build_resource_identifier,
3131
cleanup_local_hf_model_artifact,
32-
copy_model_config,
3332
create_word_icon,
3433
generate_tei_cmd_var,
3534
get_artifact_path,
@@ -969,24 +968,6 @@ def _create_model_catalog_entry(
969968
)
970969
tags[Tags.LICENSE] = validation_result.tags.get(Tags.LICENSE, UNKNOWN)
971970

972-
try:
973-
# If verified model already has a artifact json, use that.
974-
artifact_path = metadata.get(MODEL_BY_REFERENCE_OSS_PATH_KEY).value
975-
logger.info(
976-
f"Found model artifact in the service bucket. "
977-
f"Using artifact from service bucket instead of {os_path}."
978-
)
979-
980-
# todo: implement generic copy_folder method
981-
# copy model config from artifact path to user bucket
982-
copy_model_config(
983-
artifact_path=artifact_path, os_path=os_path, auth=default_signer()
984-
)
985-
except Exception:
986-
logger.debug(
987-
f"Proceeding with model registration without copying model config files at {os_path}. "
988-
f"Default configuration will be used for deployment and fine-tuning."
989-
)
990971
# Set artifact location to user bucket, and replace existing key if present.
991972
metadata.add(
992973
key=MODEL_BY_REFERENCE_OSS_PATH_KEY,

tests/unitary/with_extras/aqua/test_config.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
import json
66
import os
7-
from unittest.mock import patch
7+
import pytest
8+
from unittest.mock import patch, MagicMock
9+
10+
import oci.data_science.models
811

912
from ads.aqua.common.entities import ContainerSpec
1013
from ads.aqua.config.config import get_evaluation_service_config
14+
from ads.aqua.app import AquaApp
1115

1216

1317
class TestConfig:
@@ -37,3 +41,63 @@ def test_evaluation_service_config(self, mock_get_container_config):
3741
test_result.to_dict()
3842
== expected_result[ContainerSpec.CONTAINER_SPEC]["test_container"]
3943
)
44+
45+
@pytest.mark.parametrize(
46+
"custom_metadata",
47+
[
48+
{
49+
"category": "Other",
50+
"description": "test_desc",
51+
"key": "artifact_location",
52+
"value": "artifact_location",
53+
},
54+
{},
55+
],
56+
)
57+
@pytest.mark.parametrize("verified_model", [True, False])
58+
@pytest.mark.parametrize("path_exists", [True, False])
59+
@patch("ads.aqua.app.load_config")
60+
def test_load_config(
61+
self, mock_load_config, custom_metadata, verified_model, path_exists
62+
):
63+
mock_load_config.return_value = {"config_key": "config_value"}
64+
service_model_tag = (
65+
{"aqua_service_model": "aqua_service_model_id"} if verified_model else {}
66+
)
67+
68+
self.app = AquaApp()
69+
70+
model = {
71+
"id": "mock_id",
72+
"lifecycle_details": "mock_lifecycle_details",
73+
"lifecycle_state": "mock_lifecycle_state",
74+
"project_id": "mock_project_id",
75+
"freeform_tags": {
76+
**{
77+
"OCI_AQUA": "",
78+
},
79+
**service_model_tag,
80+
},
81+
"custom_metadata_list": [
82+
oci.data_science.models.Metadata(**custom_metadata)
83+
],
84+
}
85+
86+
self.app.ds_client.get_model = MagicMock(
87+
return_value=oci.response.Response(
88+
status=200,
89+
request=MagicMock(),
90+
headers=MagicMock(),
91+
data=oci.data_science.models.Model(**model),
92+
)
93+
)
94+
with patch("ads.aqua.app.is_path_exists", return_value=path_exists):
95+
result = self.app.get_config(
96+
model_id="test_model_id", config_file_name="test_config_file_name"
97+
)
98+
if not path_exists:
99+
assert result == {}
100+
if not custom_metadata:
101+
assert result == {}
102+
if path_exists and custom_metadata:
103+
assert result == {"config_key": "config_value"}

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,6 @@ def test_get_model_fine_tuned(
665665
@patch("ads.model.datascience_model.DataScienceModel.sync")
666666
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
667667
@patch.object(AquaModelApp, "_find_matching_aqua_model")
668-
@patch("ads.aqua.common.utils.copy_file")
669668
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
670669
@patch("ads.aqua.common.utils.load_config", return_value={})
671670
@patch("huggingface_hub.snapshot_download")
@@ -676,7 +675,6 @@ def test_import_verified_model(
676675
mock_snapshot_download,
677676
mock_load_config,
678677
mock_list_objects,
679-
mock_copy_file,
680678
mock__find_matching_aqua_model,
681679
mock_upload_artifact,
682680
mock_sync,
@@ -788,8 +786,6 @@ def test_import_verified_model(
788786
mock_subprocess.assert_not_called()
789787
mock_load_config.assert_called()
790788

791-
if not artifact_location_set:
792-
mock_copy_file.assert_called()
793789
ds_freeform_tags.pop(
794790
"ready_to_import"
795791
) # The imported model should not have this tag

0 commit comments

Comments
 (0)