Skip to content

Commit 9d2e65b

Browse files
Minor changes for Verified Models (#856)
2 parents 2a24441 + e27acc0 commit 9d2e65b

File tree

3 files changed

+39
-48
lines changed

3 files changed

+39
-48
lines changed

ads/aqua/extension/model_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def post(self, *args, **kwargs):
189189
"Please select a model with a compatible pipeline tag."
190190
)
191191

192-
# Check if it is a service/shadow model
192+
# Check if it is a service/verified model
193193
aqua_model_info: AquaModelSummary = self._find_matching_aqua_model(
194194
model_id=hf_model_info.id
195195
)

ads/aqua/model/model.py

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,12 @@
77
from threading import Lock
88
from typing import List, Optional, Union
99

10-
import oci
1110
from cachetools import TTLCache
1211
from huggingface_hub import HfApi, snapshot_download
1312
from oci.data_science.models import JobRun, Model
1413

15-
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
14+
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
1615
from ads.aqua.app import AquaApp
17-
from ads.aqua.common import utils
1816
from ads.aqua.common.enums import Tags
1917
from ads.aqua.common.errors import AquaRuntimeError
2018
from ads.aqua.common.utils import (
@@ -37,11 +35,8 @@
3735
VALIDATION_METRICS,
3836
VALIDATION_METRICS_FINAL,
3937
)
40-
from ads.aqua.data import AquaResourceIdentifier
4138
from ads.aqua.model.constants import *
4239
from ads.aqua.model.entities import *
43-
from ads.aqua.model.enums import FineTuningDefinedMetadata
44-
from ads.aqua.training.exceptions import exit_code_dict
4540
from ads.common.auth import default_signer
4641
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
4742
from ads.common.utils import get_console_link
@@ -186,7 +181,7 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod
186181
)
187182

188183
# todo: consolidate this logic in utils for model and deployment use
189-
is_shadow_type = (
184+
is_verified_type = (
190185
ds_model.freeform_tags.get(Tags.READY_TO_IMPORT, "false").upper()
191186
== READY_TO_IMPORT_STATUS
192187
)
@@ -201,7 +196,7 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod
201196
read_file(
202197
file_path=(
203198
f"{artifact_path.rstrip('/')}/config/{README}"
204-
if is_shadow_type
199+
if is_verified_type
205200
else f"{artifact_path.rstrip('/')}/{README}"
206201
),
207202
auth=self._auth,
@@ -564,7 +559,7 @@ def _create_model_catalog_entry(
564559
finetuning_container: str,
565560
inference_container_type_smc: bool,
566561
finetuning_container_type_smc: bool,
567-
shadow_model: DataScienceModel,
562+
verified_model: DataScienceModel,
568563
compartment_id: Optional[str],
569564
project_id: Optional[str],
570565
) -> DataScienceModel:
@@ -577,7 +572,7 @@ def _create_model_catalog_entry(
577572
inference_container_type_smc (bool): If true, then `inference_contianer` argument should contain service managed container name without tag information
578573
finetuning_container (str): selects service defaults
579574
finetuning_container_type_smc (bool): If true, then `finetuning_container` argument should contain service managed container name without tag
580-
shadow_model (DataScienceModel): If set, then copies all the tags and custom metadata information from the service shadow model
575+
verified_model (DataScienceModel): If set, then copies all the tags and custom metadata information from the service verified model
581576
compartment_id (Optional[str]): Compartment Id of the compartment where the model has to be created
582577
project_id (Optional[str]): Project id of the project where the model has to be created
583578
@@ -592,30 +587,29 @@ def _create_model_catalog_entry(
592587
except Exception:
593588
logger.exception(f"Could not fetch model information for {model_name}")
594589
tags = (
595-
{**shadow_model.freeform_tags, Tags.AQUA_SERVICE_MODEL_TAG: shadow_model.id}
596-
if shadow_model
590+
{
591+
**verified_model.freeform_tags,
592+
Tags.AQUA_SERVICE_MODEL_TAG: verified_model.id,
593+
}
594+
if verified_model
597595
else {Tags.AQUA_TAG: "active", Tags.BASE_MODEL_CUSTOM: "true"}
598596
)
599597
tags.update({Tags.BASE_MODEL_CUSTOM: "true"})
600598

601599
# Remove `ready_to_import` tag that might get copied from service model.
602600
tags.pop(Tags.READY_TO_IMPORT, None)
603601
metadata = None
604-
if shadow_model:
605-
# Shadow model is a model in the service catalog that either has no artifacts but contains all the necessary metadata for deploying and fine tuning.
602+
if verified_model:
603+
# Verified model is a model in the service catalog that either has no artifacts but contains all the necessary metadata for deploying and fine tuning.
606604
# If set, then we copy all the model metadata.
607-
metadata = shadow_model.custom_metadata_list
608-
if shadow_model.model_file_description:
605+
metadata = verified_model.custom_metadata_list
606+
if verified_model.model_file_description:
609607
model = model.with_model_file_description(
610-
json_dict=shadow_model.model_file_description
608+
json_dict=verified_model.model_file_description
611609
)
612610

613611
else:
614612
metadata = ModelCustomMetadata()
615-
if not inference_container:
616-
raise ValueError(
617-
f"Require Inference container information. Model: {model_name} does not have associated inference container defaults. Check docs for more information on how to pass inference container"
618-
)
619613
if finetuning_container:
620614
tags[Tags.AQUA_FINE_TUNING] = "true"
621615
metadata.add(
@@ -665,7 +659,7 @@ def _create_model_catalog_entry(
665659
)
666660

667661
try:
668-
# If shadow model already has a artifact json, use that.
662+
# If verified model already has a artifact json, use that.
669663
metadata.get(MODEL_BY_REFERENCE_OSS_PATH_KEY)
670664
logger.info(
671665
f"Found model artifact in the service bucket. "
@@ -685,7 +679,7 @@ def _create_model_catalog_entry(
685679
.with_compartment_id(compartment_id or COMPARTMENT_OCID)
686680
.with_project_id(project_id or PROJECT_OCID)
687681
.with_artifact(os_path)
688-
.with_display_name(os.path.basename(model_name))
682+
.with_display_name(model_name)
689683
.with_freeform_tags(**tags)
690684
).create(model_by_reference=True)
691685
logger.debug(model)
@@ -712,7 +706,7 @@ def register(
712706
Returns:
713707
str: Model ID of the registered model
714708
"""
715-
shadow_model_details: DataScienceModel = None
709+
verified_model_details: DataScienceModel = None
716710

717711
if not import_model_details:
718712
import_model_details = ImportModelDetails(**kwargs)
@@ -734,22 +728,22 @@ def register(
734728
f"Found service model for {import_model_details.model}: {model_service_id}"
735729
)
736730
if model_service_id:
737-
shadow_model_details = DataScienceModel.from_id(model_service_id)
738-
inference_container = shadow_model_details.custom_metadata_list.get(
731+
verified_model_details = DataScienceModel.from_id(model_service_id)
732+
inference_container = verified_model_details.custom_metadata_list.get(
739733
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
740734
).value
741735
try:
742736
# No Default finetuning container
743-
finetuning_container = shadow_model_details.custom_metadata_list.get(
737+
finetuning_container = verified_model_details.custom_metadata_list.get(
744738
AQUA_FINETUNING_CONTAINER_METADATA_NAME
745739
).value
746740
except:
747741
pass
748742

749743
# Copy the model name from the service model if `model` is ocid
750744
model_name = (
751-
shadow_model_details.display_name
752-
if shadow_model_details
745+
verified_model_details.display_name
746+
if verified_model_details
753747
else import_model_details.model
754748
)
755749

@@ -777,7 +771,8 @@ def register(
777771
os.makedirs(local_dir, exist_ok=True)
778772
# Copy the model from the cache to destination
779773
snapshot_download(
780-
repo_id=model_name, local_dir=local_dir, local_dir_use_symlinks=False
774+
repo_id=model_name,
775+
local_dir=local_dir,
781776
)
782777
# Upload to object storage
783778
model_artifact_path = upload_folder(
@@ -805,7 +800,7 @@ def register(
805800
)
806801
else import_model_details.finetuning_container_type_smc
807802
),
808-
shadow_model=shadow_model_details,
803+
verified_model=verified_model_details,
809804
compartment_id=import_model_details.compartment_id,
810805
project_id=import_model_details.project_id,
811806
)

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
import os
88
import shlex
99
import tempfile
10-
import unittest
1110
from dataclasses import asdict
1211
from importlib import reload
13-
from unittest.mock import MagicMock, PropertyMock, patch
12+
from unittest.mock import MagicMock, patch
1413

1514
import huggingface_hub
1615
import oci
@@ -70,7 +69,7 @@ class TestDataset:
7069
"compartment_id": "ocid1.compartment.oc1..<OCID>",
7170
"created_by": "ocid1.datasciencenotebooksession.oc1.iad.<OCID>",
7271
"defined_tags": {},
73-
"display_name": "ShadowModel",
72+
"display_name": "VerifiedModel",
7473
"freeform_tags": {
7574
"OCI_AQUA": "",
7675
"license": "UPL",
@@ -243,7 +242,7 @@ def test_create_model(self, mock_from_id, mock_validate, mock_create):
243242
"foundation_model_type",
244243
[
245244
"service",
246-
"shadow",
245+
"verified",
247246
],
248247
)
249248
@patch("ads.aqua.model.model.read_file")
@@ -266,12 +265,12 @@ def test_get_foundation_models(
266265
ds_model.display_name = "test_display_name"
267266
ds_model.description = "test_description"
268267
ds_model.freeform_tags = {
269-
"OCI_AQUA": "" if foundation_model_type == "shadow" else "ACTIVE",
268+
"OCI_AQUA": "" if foundation_model_type == "verified" else "ACTIVE",
270269
"license": "test_license",
271270
"organization": "test_organization",
272271
"task": "test_task",
273272
}
274-
if foundation_model_type == "shadow":
273+
if foundation_model_type == "verified":
275274
ds_model.freeform_tags["ready_to_import"] = "true"
276275
ds_model.time_created = "2024-01-19T17:57:39.158000+00:00"
277276
custom_metadata_list = ModelCustomMetadata()
@@ -305,15 +304,15 @@ def test_get_foundation_models(
305304
mock_read_file.return_value = "test_model_card"
306305

307306
model_id = (
308-
"shadow_model_id"
309-
if foundation_model_type == "shadow"
307+
"verified_model_id"
308+
if foundation_model_type == "verified"
310309
else "service_model_id"
311310
)
312311
aqua_model = self.app.get(model_id=model_id)
313312

314313
mock_from_id.assert_called_with(model_id)
315314

316-
if foundation_model_type == "shadow":
315+
if foundation_model_type == "verified":
317316
mock_read_file.assert_called_with(
318317
file_path="oci://bucket@namespace/prefix/config/README.md",
319318
auth=self.app._auth,
@@ -337,12 +336,12 @@ def test_get_foundation_models(
337336
"name": f"{ds_model.display_name}",
338337
"organization": f'{ds_model.freeform_tags["organization"]}',
339338
"project_id": f"{ds_model.project_id}",
340-
"ready_to_deploy": False if foundation_model_type == "shadow" else True,
339+
"ready_to_deploy": False if foundation_model_type == "verified" else True,
341340
"ready_to_finetune": False,
342-
"ready_to_import": True if foundation_model_type == "shadow" else False,
341+
"ready_to_import": True if foundation_model_type == "verified" else False,
343342
"search_text": (
344343
",test_license,test_organization,test_task,true"
345-
if foundation_model_type == "shadow"
344+
if foundation_model_type == "verified"
346345
else "ACTIVE,test_license,test_organization,test_task"
347346
),
348347
"tags": ds_model.freeform_tags,
@@ -524,7 +523,7 @@ def test_get_model_fine_tuned(
524523

525524
@patch("huggingface_hub.snapshot_download")
526525
@patch("subprocess.check_call")
527-
def test_import_shadow_model(
526+
def test_import_verified_model(
528527
self,
529528
mock_subprocess,
530529
mock_snapshot_download,
@@ -578,7 +577,6 @@ def test_import_shadow_model(
578577
mock_snapshot_download.assert_called_with(
579578
repo_id=hf_model,
580579
local_dir=f"{str(tmpdir)}/{hf_model}",
581-
local_dir_use_symlinks=False,
582580
)
583581
mock_subprocess.assert_called_with(
584582
shlex.split(
@@ -781,7 +779,6 @@ def test_import_any_hf_model_custom_container(
781779
mock_snapshot_download.assert_called_with(
782780
repo_id=hf_model,
783781
local_dir=f"{str(tmpdir)}/{hf_model}",
784-
local_dir_use_symlinks=False,
785782
)
786783
mock_subprocess.assert_called_with(
787784
shlex.split(
@@ -897,7 +894,6 @@ def test_import_any_hf_model_smc_container(
897894
mock_snapshot_download.assert_called_with(
898895
repo_id=hf_model,
899896
local_dir=f"{str(tmpdir)}/{hf_model}",
900-
local_dir_use_symlinks=False,
901897
)
902898
mock_subprocess.assert_called_with(
903899
shlex.split(

0 commit comments

Comments
 (0)