Skip to content

Commit 4398d3b

Browse files
Add config validation and telemetry (#872)
2 parents 8b37351 + 994adae commit 4398d3b

File tree

3 files changed

+74
-27
lines changed

3 files changed

+74
-27
lines changed

ads/aqua/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
AQUA_GA_LIST = ["id19sfcrra6z"]
3333
AQUA_MODEL_TYPE_SERVICE = "service"
3434
AQUA_MODEL_TYPE_CUSTOM = "custom"
35+
AQUA_MODEL_ARTIFACT_CONFIG = "config.json"
36+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
37+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
3538

3639
TRAINING_METRICS_FINAL = "training_metrics_final"
3740
VALIDATION_METRICS_FINAL = "validation_metrics_final"

ads/aqua/model/model.py

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
get_artifact_path,
2020
read_file,
2121
copy_model_config,
22+
load_config,
2223
)
2324
from ads.aqua.constants import (
2425
LICENSE_TXT,
@@ -32,12 +33,16 @@
3233
UNKNOWN,
3334
VALIDATION_METRICS,
3435
VALIDATION_METRICS_FINAL,
36+
AQUA_MODEL_ARTIFACT_CONFIG,
37+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
38+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
39+
AQUA_MODEL_TYPE_CUSTOM,
3540
)
3641
from ads.aqua.model.constants import *
3742
from ads.aqua.model.entities import *
3843
from ads.common.auth import default_signer
3944
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
40-
from ads.common.utils import get_console_link, is_path_exists
45+
from ads.common.utils import get_console_link
4146
from ads.config import (
4247
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
4348
AQUA_EVALUATION_CONTAINER_METADATA_NAME,
@@ -688,9 +693,16 @@ def register(
688693
if not import_model_details:
689694
import_model_details = ImportModelDetails(**kwargs)
690695

691-
if not is_path_exists(
692-
f"{import_model_details.os_path.rstrip('/')}/config.json"
693-
):
696+
try:
697+
model_config = load_config(
698+
file_path=import_model_details.os_path,
699+
config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
700+
)
701+
except Exception as ex:
702+
logger.error(
703+
f"Exception occurred while loading config file from {import_model_details.os_path}"
704+
f"Exception message: {ex}"
705+
)
694706
raise AquaRuntimeError(
695707
f"The model path {import_model_details.os_path} does not contain the file config.json. "
696708
f"Please check if the path is correct or the model artifacts are available at this location."
@@ -713,6 +725,30 @@ def register(
713725
)
714726
if model_service_id:
715727
verified_model_details = DataScienceModel.from_id(model_service_id)
728+
try:
729+
metadata_model_type = verified_model_details.custom_metadata_list.get(
730+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
731+
).value
732+
if metadata_model_type:
733+
if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config:
734+
if (
735+
model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]
736+
!= metadata_model_type
737+
):
738+
raise AquaRuntimeError(
739+
f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
740+
f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for "
741+
f"the model {import_model_details.model}. Please check if the path is correct or "
742+
f"the correct model artifacts are available at this location."
743+
f""
744+
)
745+
else:
746+
logger.debug(
747+
f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
748+
f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
749+
)
750+
except:
751+
pass
716752

717753
# Copy the model name from the service model if `model` is ocid
718754
model_name = (
@@ -734,19 +770,14 @@ def register(
734770
# registered model will always have inference and evaluation container, but
735771
# fine-tuning container may be not set
736772
inference_container = ds_model.custom_metadata_list.get(
737-
ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
738-
ModelCustomMetadataItem(key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER),
773+
ModelCustomMetadataFields.DEPLOYMENT_CONTAINER
739774
).value
740775
evaluation_container = ds_model.custom_metadata_list.get(
741776
ModelCustomMetadataFields.EVALUATION_CONTAINER,
742-
ModelCustomMetadataItem(key=ModelCustomMetadataFields.EVALUATION_CONTAINER),
743777
).value
744778
try:
745779
finetuning_container = ds_model.custom_metadata_list.get(
746780
ModelCustomMetadataFields.FINETUNE_CONTAINER,
747-
ModelCustomMetadataItem(
748-
key=ModelCustomMetadataFields.FINETUNE_CONTAINER
749-
),
750781
).value
751782
except:
752783
finetuning_container = None
@@ -756,18 +787,31 @@ def register(
756787
project_id=ds_model.project_id,
757788
model_card=str(
758789
read_file(
759-
file_path=(
760-
f"{import_model_details.os_path.rstrip('/')}/config/{README}"
761-
if verified_model_details
762-
else f"{import_model_details.os_path.rstrip('/')}/{README}"
763-
),
790+
file_path=f"{import_model_details.os_path.rstrip('/')}/{README}",
764791
auth=default_signer(),
765792
)
766793
),
767794
inference_container=inference_container,
768795
finetuning_container=finetuning_container,
769796
evaluation_container=evaluation_container,
770797
)
798+
799+
if verified_model_details:
800+
telemetry_model_name = model_name
801+
else:
802+
if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config:
803+
telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
804+
elif AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config:
805+
telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
806+
else:
807+
telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
808+
809+
self.telemetry.record_event_async(
810+
category="aqua/model",
811+
action="register",
812+
detail=telemetry_model_name,
813+
)
814+
771815
return AquaModel(**aqua_model_attributes)
772816

773817
def _if_show(self, model: DataScienceModel) -> bool:

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
ModelProvenanceMetadata,
2727
ModelTaxonomyMetadata,
2828
)
29-
from ads.aqua.common.errors import AquaRuntimeError
29+
from ads.aqua.common.errors import AquaRuntimeError, AquaFileNotFoundError
3030
from ads.model.service.oci_datascience_model import OCIDataScienceModel
3131

3232

@@ -534,10 +534,10 @@ def test_get_model_fine_tuned(
534534
)
535535
@patch("ads.aqua.common.utils.copy_file")
536536
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
537-
@patch("ads.common.utils.is_path_exists", return_value=True)
537+
@patch("ads.aqua.common.utils.load_config", return_value={})
538538
def test_import_verified_model(
539539
self,
540-
mock_is_path_exists,
540+
mock_load_config,
541541
mock_list_objects,
542542
mock_copy_file,
543543
artifact_location_set,
@@ -612,7 +612,7 @@ def test_import_verified_model(
612612
"aqua_service_model": "test_model_id",
613613
**ds_freeform_tags,
614614
}
615-
mock_is_path_exists.assert_called()
615+
mock_load_config.assert_called()
616616

617617
assert model.inference_container == "odsc-tgi-serving"
618618
assert model.finetuning_container is None
@@ -621,8 +621,8 @@ def test_import_verified_model(
621621
assert model.ready_to_deploy is True
622622
assert model.ready_to_finetune is False
623623

624-
@patch("ads.common.utils.is_path_exists", return_value=True)
625-
def test_import_any_model_no_containers_specified(self, mock_is_path_exists):
624+
@patch("ads.aqua.common.utils.load_config", return_value={})
625+
def test_import_any_model_no_containers_specified(self, mock_load_config):
626626
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
627627
ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock()
628628
DataScienceModel.upload_artifact = MagicMock()
@@ -655,8 +655,8 @@ def test_import_any_model_no_containers_specified(self, mock_is_path_exists):
655655
os_path=os_path,
656656
)
657657

658-
@patch("ads.common.utils.is_path_exists", return_value=True)
659-
def test_import_model_with_project_compartment_override(self, mock_is_path_exists):
658+
@patch("ads.aqua.common.utils.load_config", return_value={})
659+
def test_import_model_with_project_compartment_override(self, mock_load_config):
660660
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
661661
ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock()
662662
DataScienceModel.upload_artifact = MagicMock()
@@ -704,8 +704,8 @@ def test_import_model_with_project_compartment_override(self, mock_is_path_exist
704704
assert model.compartment_id == compartment_override
705705
assert model.project_id == project_override
706706

707-
@patch("ads.common.utils.is_path_exists", return_value=False)
708-
def test_import_model_with_missing_artifact(self, mock_is_path_exists):
707+
@patch("ads.aqua.common.utils.load_config", side_effect=AquaFileNotFoundError)
708+
def test_import_model_with_missing_config(self, mock_load_config):
709709
"""Test for validating if error is returned when model artifacts are incomplete or not available."""
710710
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
711711
model_name = "oracle/aqua-1t-mega-model"
@@ -717,10 +717,10 @@ def test_import_model_with_missing_artifact(self, mock_is_path_exists):
717717
os_path=os_path,
718718
)
719719

720-
@patch("ads.common.utils.is_path_exists", return_value=True)
720+
@patch("ads.aqua.common.utils.load_config", return_value={})
721721
def test_import_any_model_smc_container(
722722
self,
723-
mock_is_path_exists,
723+
mock_load_config,
724724
):
725725
my_model = "oracle/aqua-1t-mega-model"
726726
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)

0 commit comments

Comments
 (0)