Skip to content

Commit 14b683c

Browse files
copy configs for verified models
1 parent 6bb395a commit 14b683c

File tree

3 files changed

+114
-10
lines changed

3 files changed

+114
-10
lines changed

ads/aqua/common/utils.py

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from ads.common.extended_enum import ExtendedEnumMeta
3434
from ads.common.object_storage_details import ObjectStorageDetails
3535
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
36-
from ads.common.utils import get_console_link, upload_to_os
36+
from ads.common.utils import get_console_link, upload_to_os, copy_file
3737
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
3838
from ads.model import DataScienceModel, ModelVersionSet
3939

@@ -100,6 +100,23 @@ def get_status(evaluation_status: str, job_run_status: str = None):
100100
JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION: "Missing jobrun information.",
101101
}
102102

103+
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = dict(
104+
datasciencemodel="models",
105+
datasciencemodeldeployment="model-deployments",
106+
datasciencemodeldeploymentdev="model-deployments",
107+
datasciencemodeldeploymentint="model-deployments",
108+
datasciencemodeldeploymentpre="model-deployments",
109+
datasciencejob="jobs",
110+
datasciencejobrun="job-runs",
111+
datasciencejobrundev="job-runs",
112+
datasciencejobrunint="job-runs",
113+
datasciencejobrunpre="job-runs",
114+
datasciencemodelversionset="model-version-sets",
115+
datasciencemodelversionsetpre="model-version-sets",
116+
datasciencemodelversionsetint="model-version-sets",
117+
datasciencemodelversionsetdev="model-version-sets",
118+
)
119+
103120

104121
def random_color_generator(word: str):
105122
seed = sum([ord(c) for c in word]) % 13
@@ -227,12 +244,10 @@ def is_valid_ocid(ocid: str) -> bool:
227244
bool:
228245
Whether the given ocid is valid.
229246
"""
230-
# TODO: revisit pattern
231-
pattern = (
232-
r"^ocid1\.([a-z0-9_]+)\.([a-z0-9]+)\.([a-z0-9-]*)(\.[^.]+)?\.([a-z0-9_]+)$"
233-
)
234-
match = re.match(pattern, ocid)
235-
return True
247+
248+
if not ocid:
249+
return False
250+
return ocid.lower().startswith("ocid")
236251

237252

238253
def get_resource_type(ocid: str) -> str:
@@ -557,7 +572,7 @@ def fetch_service_compartment() -> Union[str, None]:
557572
config_file_name=CONTAINER_INDEX,
558573
)
559574
except Exception as e:
560-
logger.error(
575+
logger.debug(
561576
f"Config file {config_file_name}/{CONTAINER_INDEX} to fetch service compartment OCID could not be found. "
562577
f"\n{str(e)}."
563578
)
@@ -824,3 +839,51 @@ def get_combined_params(params1: str = None, params2: str = None) -> str:
824839
]
825840

826841
return " ".join(combined_params)
842+
843+
844+
def copy_model_config(artifact_path: str, os_path: str, auth: dict = None):
845+
"""Copies the aqua model config folder from the artifact path to the user provided object storage path.
846+
The config folder is overwritten if the files already exist at the destination path.
847+
848+
Parameters
849+
----------
850+
artifact_path:
851+
Path of the aqua model where config folder is available.
852+
os_path:
853+
User provided path where config folder will be copied.
854+
auth: (Dict, optional). Defaults to None.
855+
The default authentication is set using `ads.set_auth` API. If you need to override the
856+
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
857+
authentication signer and kwargs required to instantiate IdentityClient object.
858+
859+
Returns
860+
-------
861+
None
862+
Nothing.
863+
"""
864+
865+
try:
866+
source_dir = ObjectStorageDetails(
867+
AQUA_SERVICE_MODELS_BUCKET,
868+
CONDA_BUCKET_NS,
869+
f"{os.path.dirname(artifact_path).rstrip('/')}/config",
870+
).path
871+
dest_dir = f"{os_path.rstrip('/')}/config"
872+
873+
oss_details = ObjectStorageDetails.from_path(source_dir)
874+
objects = oss_details.list_objects(fields="name").objects
875+
876+
for obj in objects:
877+
source_path = ObjectStorageDetails(
878+
AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, obj.name
879+
).path
880+
destination_path = os.path.join(dest_dir, os.path.basename(obj.name))
881+
copy_file(
882+
uri_src=source_path,
883+
uri_dst=destination_path,
884+
force_overwrite=True,
885+
auth=auth,
886+
)
887+
except Exception as ex:
888+
logger.debug(ex)
889+
logger.debug(f"Failed to copy config folder from {artifact_path} to {os_path}.")

ads/aqua/model/model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
is_service_managed_container,
2222
read_file,
2323
upload_folder,
24+
copy_model_config,
2425
)
2526
from ads.aqua.constants import (
2627
LICENSE_TXT,
@@ -660,11 +661,17 @@ def _create_model_catalog_entry(
660661

661662
try:
662663
# If verified model already has a artifact json, use that.
663-
metadata.get(MODEL_BY_REFERENCE_OSS_PATH_KEY)
664+
artifact_path = metadata.get(MODEL_BY_REFERENCE_OSS_PATH_KEY).value
664665
logger.info(
665666
f"Found model artifact in the service bucket. "
666667
f"Using artifact from service bucket instead of {os_path}"
667668
)
669+
670+
# copy model config from artifact path to user bucket
671+
copy_model_config(
672+
artifact_path=artifact_path, os_path=os_path, auth=self._auth
673+
)
674+
668675
except:
669676
# Add artifact from user bucket
670677
metadata.add(

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,12 +521,24 @@ def test_get_model_fine_tuned(
521521
"evaluation_container": "odsc-llm-evaluate",
522522
}
523523

524+
@pytest.mark.parametrize(
525+
"artifact_location_set",
526+
[
527+
True,
528+
False,
529+
],
530+
)
531+
@patch("ads.aqua.common.utils.copy_file")
532+
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
524533
@patch("huggingface_hub.snapshot_download")
525534
@patch("subprocess.check_call")
526535
def test_import_verified_model(
527536
self,
528537
mock_subprocess,
529538
mock_snapshot_download,
539+
mock_list_objects,
540+
mock_copy_file,
541+
artifact_location_set,
530542
):
531543
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
532544
ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock()
@@ -536,6 +548,16 @@ def test_import_verified_model(
536548
DataScienceModel.sync = MagicMock()
537549
OCIDataScienceModel.create = MagicMock()
538550

551+
# The name attribute cannot be mocked during creation of the mock object,
552+
# hence attach it separately to the mocked objects.
553+
artifact_path = "service_models/model-name/commit-id/artifact"
554+
obj1 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150)
555+
obj1.name = f"{artifact_path}/config/deployment_config.json"
556+
obj2 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150)
557+
obj2.name = f"{artifact_path}/config/ft_config.json"
558+
objects = [obj1, obj2]
559+
mock_list_objects.return_value = MagicMock(objects=objects)
560+
539561
ds_model = DataScienceModel()
540562
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
541563
hf_model = "oracle/aqua-1t-mega-model"
@@ -562,6 +584,14 @@ def test_import_verified_model(
562584
custom_metadata_list.add(
563585
**{"key": "evaluation-container", "value": "odsc-llm-evaluate"}
564586
)
587+
if not artifact_location_set:
588+
custom_metadata_list.add(
589+
**{
590+
"key": "artifact_location",
591+
"value": artifact_path,
592+
"description": "artifact location",
593+
}
594+
)
565595
ds_model.with_custom_metadata_list(custom_metadata_list)
566596
ds_model.set_spec(ds_model.CONST_MODEL_FILE_DESCRIPTION, {})
567597
ds_model.dsc_model = MagicMock(id="test_model_id")
@@ -579,6 +609,8 @@ def test_import_verified_model(
579609
local_dir=f"{str(tmpdir)}/{hf_model}",
580610
local_dir_use_symlinks=False,
581611
)
612+
if not artifact_location_set:
613+
mock_copy_file.assert_called()
582614
mock_subprocess.assert_called_with(
583615
shlex.split(
584616
f"oci os object bulk-upload --src-dir {str(tmpdir)}/{hf_model} --prefix prefix/path/{hf_model}/ -bn aqua-bkt -ns aqua-ns --auth api_key --profile DEFAULT"
@@ -613,7 +645,9 @@ def test_import_verified_model(
613645
},
614646
{
615647
"key": "artifact_location",
616-
"value": f"{os_path}/{hf_model}/",
648+
"value": f"{os_path}/{hf_model}/"
649+
if artifact_location_set
650+
else artifact_path,
617651
"description": "artifact location",
618652
"category": "Other",
619653
},

0 commit comments

Comments
 (0)