Skip to content

[AQUA][WIP] Added support for deploy stack model. #1223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: feature/model_group
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 69 additions & 33 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ def create(
freeform_tags: Optional[Dict] = None,
defined_tags: Optional[Dict] = None,
**kwargs,
) -> DataScienceModel:
) -> Union[DataScienceModel, DataScienceModelGroup]:
"""
Creates a custom Aqua model from a service model.
Creates a custom Aqua model or model group from a service model.

Parameters
----------
Expand All @@ -167,28 +167,21 @@ def create(

Returns
-------
DataScienceModel
The instance of DataScienceModel.
Union[DataScienceModel, DataScienceModelGroup]
The instance of DataScienceModel or DataScienceModelGroup.
"""
fine_tune_weights = (
model_id.fine_tune_weights
if isinstance(model_id, AquaMultiModelRef)
else []
)
model_id = (
model_id.model_id if isinstance(model_id, AquaMultiModelRef) else model_id
)
service_model = DataScienceModel.from_id(model_id)
target_project = project_id or PROJECT_OCID
target_compartment = compartment_id or COMPARTMENT_OCID

# Skip model copying if it is registered model or fine-tuned model
if (
service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) is not None
or service_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG)
is not None
):
logger.info(
f"Aqua Model {model_id} already exists in the user's compartment."
"Skipped copying."
)
return service_model

# combine tags
combined_freeform_tags = {
**(service_model.freeform_tags or {}),
Expand All @@ -199,23 +192,66 @@ def create(
**(defined_tags or {}),
}

custom_model = (
DataScienceModel()
.with_compartment_id(target_compartment)
.with_project_id(target_project)
.with_model_file_description(json_dict=service_model.model_file_description)
.with_display_name(service_model.display_name)
.with_description(service_model.description)
.with_freeform_tags(**combined_freeform_tags)
.with_defined_tags(**combined_defined_tags)
.with_custom_metadata_list(service_model.custom_metadata_list)
.with_defined_metadata_list(service_model.defined_metadata_list)
.with_provenance_metadata(service_model.provenance_metadata)
.create(model_by_reference=True, **kwargs)
)
logger.info(
f"Aqua Model {custom_model.id} created with the service model {model_id}."
)
custom_model = None
if fine_tune_weights:
custom_model = (
DataScienceModelGroup()
.with_compartment_id(target_compartment)
.with_project_id(target_project)
.with_display_name(service_model.display_name)
.with_description(service_model.description)
.with_freeform_tags(**combined_freeform_tags)
.with_defined_tags(**combined_defined_tags)
.with_custom_metadata_list(service_model.custom_metadata_list)
.with_base_model_id(model_id)
.with_member_models(
[
{
"inference_key": fine_tune_weight.model_name,
"model_id": fine_tune_weight.model_id,
}
for fine_tune_weight in fine_tune_weights
]
)
.create()
)

logger.info(
f"Aqua Model Group {custom_model.id} created with the service model {model_id}."
)
else:
# Skip model copying if it is registered model or fine-tuned model
if (
service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None)
is not None
or service_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG)
is not None
):
logger.info(
f"Aqua Model {model_id} already exists in the user's compartment."
"Skipped copying."
)
return service_model

custom_model = (
DataScienceModel()
.with_compartment_id(target_compartment)
.with_project_id(target_project)
.with_model_file_description(
json_dict=service_model.model_file_description
)
.with_display_name(service_model.display_name)
.with_description(service_model.description)
.with_freeform_tags(**combined_freeform_tags)
.with_defined_tags(**combined_defined_tags)
.with_custom_metadata_list(service_model.custom_metadata_list)
.with_defined_metadata_list(service_model.defined_metadata_list)
.with_provenance_metadata(service_model.provenance_metadata)
.create(model_by_reference=True, **kwargs)
)
logger.info(
f"Aqua Model {custom_model.id} created with the service model {model_id}."
)

# Track unique models that were created in the user's compartment
self.telemetry.record_event_async(
Expand Down
1 change: 1 addition & 0 deletions ads/aqua/modeldeployment/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@

DEFAULT_WAIT_TIME = 12000
DEFAULT_POLL_INTERVAL = 10
DEFAULT_DEPLOYMENT_TYPE = "MODEL_STACK"
35 changes: 29 additions & 6 deletions ads/aqua/modeldeployment/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@
ModelDeploymentConfigSummary,
MultiModelDeploymentConfigLoader,
)
from ads.aqua.modeldeployment.constants import DEFAULT_POLL_INTERVAL, DEFAULT_WAIT_TIME
from ads.aqua.modeldeployment.constants import (
DEFAULT_DEPLOYMENT_TYPE,
DEFAULT_POLL_INTERVAL,
DEFAULT_WAIT_TIME,
)
from ads.aqua.modeldeployment.entities import (
AquaDeployment,
AquaDeploymentDetail,
Expand Down Expand Up @@ -162,6 +166,7 @@ def create(
cmd_var (Optional[List[str]]): Command variables for the container runtime.
freeform_tags (Optional[Dict]): Freeform tags for model deployment.
defined_tags (Optional[Dict]): Defined tags for model deployment.
deployment_type (Optional[str]): The type of model deployment.

Returns
-------
Expand Down Expand Up @@ -206,13 +211,28 @@ def create(

# Create an AquaModelApp instance once to perform the deployment creation.
model_app = AquaModelApp()
if create_deployment_details.model_id:
if (
create_deployment_details.model_id
or create_deployment_details.deployment_type == DEFAULT_DEPLOYMENT_TYPE
):
model_id = create_deployment_details.model_id
if not model_id:
if len(create_deployment_details.models) != 1:
raise AquaValueError(
"Invalid 'models' provided. Only one base model is required for model stack deployment."
)
model_id = create_deployment_details.models[0]

service_model_id = (
model_id if isinstance(model_id, str) else model_id.model_id
)
logger.debug(
f"Single model ({create_deployment_details.model_id}) provided. "
f"Single model ({service_model_id}) provided. "
"Delegating to single model creation method."
)

aqua_model = model_app.create(
model_id=create_deployment_details.model_id,
model_id=model_id,
compartment_id=compartment_id,
project_id=project_id,
freeform_tags=freeform_tags,
Expand Down Expand Up @@ -677,7 +697,7 @@ def _build_model_group_config(

def _create(
self,
aqua_model: DataScienceModel,
aqua_model: Union[DataScienceModel, DataScienceModelGroup],
create_deployment_details: CreateModelDeploymentDetails,
container_config: Dict,
) -> AquaDeployment:
Expand Down Expand Up @@ -711,7 +731,10 @@ def _create(
tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK, UNKNOWN)})

# Set up info to get deployment config
config_source_id = create_deployment_details.model_id
config_source_id = (
create_deployment_details.model_id
or create_deployment_details.models[0].model_id
)
model_name = aqua_model.display_name

# set up env and cmd var
Expand Down
3 changes: 3 additions & 0 deletions ads/aqua/modeldeployment/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,9 @@ class CreateModelDeploymentDetails(BaseModel):
defined_tags: Optional[Dict] = Field(
None, description="Defined tags for model deployment."
)
deployment_type: Optional[str] = Field(
None, description="The type of model deployment."
)

@model_validator(mode="before")
@classmethod
Expand Down
44 changes: 29 additions & 15 deletions ads/model/datascience_model_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ModelGroup,
ModelGroupDetails,
ModelGroupSummary,
StackedModelGroupDetails,
UpdateModelGroupDetails,
)
except ModuleNotFoundError as err:
Expand Down Expand Up @@ -511,28 +512,38 @@ def create(

def _build_model_group_details(self) -> dict:
"""Builds model group details dict for creating or updating oci model group."""
model_group_details = HomogeneousModelGroupDetails(
custom_metadata_list=[
CustomMetadata(
key=custom_metadata.key,
value=custom_metadata.value,
description=custom_metadata.description,
category=custom_metadata.category,
)
for custom_metadata in self.custom_metadata_list._to_oci_metadata()
]
)
custom_metadata_list = [
CustomMetadata(
key=custom_metadata.key,
value=custom_metadata.value,
description=custom_metadata.description,
category=custom_metadata.category,
)
for custom_metadata in self.custom_metadata_list._to_oci_metadata()
]
member_model_details = [
MemberModelDetails(**member_model) for member_model in self.member_models
]

if self.base_model_id:
model_group_details = StackedModelGroupDetails(
custom_metadata_list=custom_metadata_list,
base_model_id=self.base_model_id,
)
member_model_details.append(MemberModelDetails(model_id=self.base_model_id))
else:
model_group_details = HomogeneousModelGroupDetails(
custom_metadata_list=custom_metadata_list
)

member_model_entries = MemberModelEntries(
member_model_details=[
MemberModelDetails(**member_model)
for member_model in self.member_models
]
member_model_details=member_model_details
)

build_model_group_details = copy.deepcopy(self._spec)
build_model_group_details.pop(self.CONST_CUSTOM_METADATA_LIST)
build_model_group_details.pop(self.CONST_MEMBER_MODELS)
build_model_group_details.pop(self.CONST_BASE_MODEL_ID)
build_model_group_details.update(
{
self.CONST_COMPARTMENT_ID: self.compartment_id or COMPARTMENT_OCID,
Expand Down Expand Up @@ -581,6 +592,9 @@ def _update_from_oci_model(
)
self.set_spec(self.CONST_CUSTOM_METADATA_LIST, model_custom_metadata)

if hasattr(model_group_details, "base_model_id"):
self.set_spec(self.CONST_BASE_MODEL_ID, model_group_details.base_model_id)

# only updates member_models when oci_model_group_instance is an instance of
# oci.data_science.models.ModelGroup as oci.data_science.models.ModelGroupSummary
# doesn't have member_model_entries property.
Expand Down