Skip to content

extract sample_ratio for automlx explainability #1167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ads/aqua/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
HF_LOGIN_DEFAULT_TIMEOUT = 2
MODEL_NAME_DELIMITER = ";"
AQUA_TROUBLESHOOTING_LINK = "https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/troubleshooting-tips.md"
MODEL_FILE_DESCRIPTION_VERSION = "1.0"
MODEL_FILE_DESCRIPTION_TYPE = "modelOSSReferenceDescription"

TRAINING_METRICS_FINAL = "training_metrics_final"
VALIDATION_METRICS_FINAL = "validation_metrics_final"
Expand Down
83 changes: 81 additions & 2 deletions ads/aqua/model/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,19 @@

import oci
from huggingface_hub import hf_api
from pydantic import BaseModel
from pydantic import BaseModel, Field
from pydantic.alias_generators import to_camel

from ads.aqua import logger
from ads.aqua.app import CLIBuilderMixin
from ads.aqua.common import utils
from ads.aqua.constants import LIFECYCLE_DETAILS_MISSING_JOBRUN, UNKNOWN_VALUE
from ads.aqua.config.utils.serializer import Serializable
from ads.aqua.constants import (
LIFECYCLE_DETAILS_MISSING_JOBRUN,
MODEL_FILE_DESCRIPTION_TYPE,
MODEL_FILE_DESCRIPTION_VERSION,
UNKNOWN_VALUE,
)
from ads.aqua.data import AquaResourceIdentifier
from ads.aqua.model.enums import FineTuningDefinedMetadata
from ads.aqua.training.exceptions import exit_code_dict
Expand Down Expand Up @@ -304,3 +311,75 @@ class ImportModelDetails(CLIBuilderMixin):

def __post_init__(self):
self._command = "model register"


class ModelFileInfo(Serializable):
"""Describes the file information of this model.

Attributes:
name (str): The name of the model artifact file.
version (str): The version of the model artifact file.
size_in_bytes (int): The size of the model artifact file in bytes.
"""

name: str = Field(..., description="The name of model artifact file.")
version: str = Field(..., description="The version of model artifact file.")
size_in_bytes: int = Field(
..., description="The size of model artifact file in bytes."
)

class Config:
alias_generator = to_camel
extra = "allow"


class ModelArtifactInfo(Serializable):
"""Describes the artifact information of this model.

Attributes:
namespace (str): The namespace of the model artifact location.
bucket_name (str): The bucket name of model artifact location.
prefix (str): The prefix of model artifact location.
objects: (List[ModelFileInfo]): A list of model artifact objects.
"""

namespace: str = Field(
..., description="The name space of model artifact location."
)
bucket_name: str = Field(
..., description="The bucket name of model artifact location."
)
prefix: str = Field(..., description="The prefix of model artifact location.")
objects: List[ModelFileInfo] = Field(
..., description="List of model artifact objects."
)

class Config:
alias_generator = to_camel
extra = "allow"


class ModelFileDescription(Serializable):
"""Describes the model file description.

Attributes:
version (str): The version of the model file description. Defaults to `1.0`.
type (str): The type of model file description. Defaults to `modelOSSReferenceDescription`.
models List[ModelArtifactInfo]: A list of model artifact information.
"""

version: str = Field(
default=MODEL_FILE_DESCRIPTION_VERSION,
description="The version of model file description.",
)
type: str = Field(
default=MODEL_FILE_DESCRIPTION_TYPE,
description="The type of model file description.",
)
models: List[ModelArtifactInfo] = Field(
..., description="List of model artifact information."
)

class Config:
alias_generator = to_camel
extra = "allow"
27 changes: 22 additions & 5 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
AquaModelReadme,
AquaModelSummary,
ImportModelDetails,
ModelFileDescription,
ModelValidationResult,
)
from ads.aqua.model.enums import MultiModelSupportedTaskType
Expand Down Expand Up @@ -271,8 +272,8 @@ def create_multi(
"Model list cannot be empty. Please provide at least one model for deployment."
)

artifact_list = []
display_name_list = []
model_file_description_list: List[ModelFileDescription] = []
model_custom_metadata = ModelCustomMetadata()

service_inference_containers = (
Expand All @@ -299,6 +300,7 @@ def create_multi(
for model in models:
source_model = DataScienceModel.from_id(model.model_id)
display_name = source_model.display_name
model_file_description = source_model.model_file_description
# Update model name in user's input model
model.model_name = model.model_name or display_name

Expand All @@ -324,7 +326,15 @@ def create_multi(
# Update model artifact location in user's input model
model.artifact_location = model_artifact_path

artifact_list.append(model_artifact_path)
if not model_file_description:
raise AquaValueError(
f"Model '{display_name}' (ID: {model.model_id}) has no file description. "
"Please register the model first."
)

model_file_description_list.append(
ModelFileDescription(**model_file_description)
)

# Validate deployment container consistency
deployment_container = source_model.custom_metadata_list.get(
Expand Down Expand Up @@ -402,9 +412,16 @@ def create_multi(
.with_custom_metadata_list(model_custom_metadata)
)

# Attach artifacts
for artifact in artifact_list:
custom_model.add_artifact(uri=artifact)
# Update multi model file description to attach artifacts
custom_model.with_model_file_description(
json_dict=ModelFileDescription(
models=[
models
for model_file_description in model_file_description_list
for models in model_file_description.models
]
).model_dump(by_alias=True)
)

# Finalize creation
custom_model.create(model_by_reference=True)
Expand Down
49 changes: 36 additions & 13 deletions tests/unitary/with_extras/aqua/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,35 @@ class TestDataset:
},
]

model_file_description = {
"version": "1.0",
"type": "modelOSSReferenceDescription",
"models": [
{
"namespace": "test_namespace",
"bucketName": "test_bucket",
"prefix": "models/meta-llama/Llama-3.2-3B-Instruct",
"objects": [
{
"name": "models/meta-llama/Llama-3.2-3B-Instruct/.gitattributes",
"version": "bfbf278c-10af-4f2c-8240-11fed02e1322",
"sizeInBytes": 1519,
},
{
"name": "models/meta-llama/Llama-3.2-3B-Instruct/LICENSE.txt",
"version": "4238d1e2-d826-4300-a344-0ead410afa27",
"sizeInBytes": 7712,
},
{
"name": "models/meta-llama/Llama-3.2-3B-Instruct/README.md",
"version": "57382552-9ad0-4546-b38c-c96634f3b8a2",
"sizeInBytes": 41744,
},
],
}
],
}

SERVICE_COMPARTMENT_ID = "ocid1.compartment.oc1..<OCID>"
COMPARTMENT_ID = "ocid1.compartment.oc1..<UNIQUE_OCID>"
SERVICE_MODEL_ID = "ocid1.datasciencemodel.oc1.iad.<OCID>"
Expand Down Expand Up @@ -360,24 +389,20 @@ def test_create_model(self, mock_from_id, mock_validate, mock_create):
)
assert model.provenance_metadata.training_id == "test_training_id"

@patch.object(DataScienceModel, "add_artifact")
@patch.object(DataScienceModel, "create_custom_metadata_artifact")
@patch.object(DataScienceModel, "create")
@patch("ads.model.datascience_model.validate")
@patch.object(AquaApp, "get_container_config")
@patch.object(DataScienceModel, "from_id")
def test_create_multimodel(
self,
mock_from_id,
mock_get_container_config,
mock_validate,
mock_create,
mock_create_custom_metadata_artifact,
mock_add_artifact,
):
mock_get_container_config.return_value = get_container_config()
mock_model = MagicMock()
mock_model.model_file_description = {"test_key": "test_value"}
mock_model.model_file_description = TestDataset.model_file_description
mock_model.display_name = "test_display_name"
mock_model.description = "test_description"
mock_model.freeform_tags = {
Expand All @@ -396,14 +421,14 @@ def test_create_multimodel(
model_info_1 = AquaMultiModelRef(
model_id="test_model_id_1",
gpu_count=2,
model_task = "text_embedding",
model_task="text_embedding",
env_var={"params": "--trust-remote-code --max-model-len 60000"},
)

model_info_2 = AquaMultiModelRef(
model_id="test_model_id_2",
gpu_count=2,
model_task = "image_text_to_text",
model_task="image_text_to_text",
env_var={"params": "--trust-remote-code --max-model-len 32000"},
)

Expand Down Expand Up @@ -455,10 +480,10 @@ def test_create_multimodel(
mock_model.freeform_tags["task"] = "unsupported_task"
with pytest.raises(AquaValueError):
model = self.app.create_multi(
models=[model_info_1, model_info_2],
project_id="test_project_id",
compartment_id="test_compartment_id",
)
models=[model_info_1, model_info_2],
project_id="test_project_id",
compartment_id="test_compartment_id",
)

mock_model.freeform_tags["task"] = "text-generation"
model_info_1.model_task = "text_embedding"
Expand All @@ -470,9 +495,7 @@ def test_create_multimodel(
compartment_id="test_compartment_id",
)

mock_add_artifact.assert_called()
mock_from_id.assert_called()
mock_validate.assert_not_called()
mock_create.assert_called_with(model_by_reference=True)

mock_model.compartment_id = TestDataset.SERVICE_COMPARTMENT_ID
Expand Down
Loading