diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index f7204fd72..878db2993 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -147,9 +147,9 @@ def create( freeform_tags: Optional[Dict] = None, defined_tags: Optional[Dict] = None, **kwargs, - ) -> DataScienceModel: + ) -> Union[DataScienceModel, DataScienceModelGroup]: """ - Creates a custom Aqua model from a service model. + Creates a custom Aqua model or model group from a service model. Parameters ---------- @@ -167,9 +167,14 @@ def create( Returns ------- - DataScienceModel - The instance of DataScienceModel. + Union[DataScienceModel, DataScienceModelGroup] + The instance of DataScienceModel or DataScienceModelGroup. """ + fine_tune_weights = ( + model_id.fine_tune_weights + if isinstance(model_id, AquaMultiModelRef) + else [] + ) model_id = ( model_id.model_id if isinstance(model_id, AquaMultiModelRef) else model_id ) @@ -177,18 +182,6 @@ def create( target_project = project_id or PROJECT_OCID target_compartment = compartment_id or COMPARTMENT_OCID - # Skip model copying if it is registered model or fine-tuned model - if ( - service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) is not None - or service_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG) - is not None - ): - logger.info( - f"Aqua Model {model_id} already exists in the user's compartment." - "Skipped copying." - ) - return service_model - # combine tags combined_freeform_tags = { **(service_model.freeform_tags or {}), @@ -199,23 +192,66 @@ def create( **(defined_tags or {}), } - custom_model = ( - DataScienceModel() - .with_compartment_id(target_compartment) - .with_project_id(target_project) - .with_model_file_description(json_dict=service_model.model_file_description) - .with_display_name(service_model.display_name) - .with_description(service_model.description) - .with_freeform_tags(**combined_freeform_tags) - .with_defined_tags(**combined_defined_tags) - .with_custom_metadata_list(service_model.custom_metadata_list) - .with_defined_metadata_list(service_model.defined_metadata_list) - .with_provenance_metadata(service_model.provenance_metadata) - .create(model_by_reference=True, **kwargs) - ) - logger.info( - f"Aqua Model {custom_model.id} created with the service model {model_id}." - ) + custom_model = None + if fine_tune_weights: + custom_model = ( + DataScienceModelGroup() + .with_compartment_id(target_compartment) + .with_project_id(target_project) + .with_display_name(service_model.display_name) + .with_description(service_model.description) + .with_freeform_tags(**combined_freeform_tags) + .with_defined_tags(**combined_defined_tags) + .with_custom_metadata_list(service_model.custom_metadata_list) + .with_base_model_id(model_id) + .with_member_models( + [ + { + "inference_key": fine_tune_weight.model_name, + "model_id": fine_tune_weight.model_id, + } + for fine_tune_weight in fine_tune_weights + ] + ) + .create() + ) + + logger.info( + f"Aqua Model Group {custom_model.id} created with the service model {model_id}." + ) + else: + # Skip model copying if it is registered model or fine-tuned model + if ( + service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) + is not None + or service_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG) + is not None + ): + logger.info( + f"Aqua Model {model_id} already exists in the user's compartment." + "Skipped copying." + ) + return service_model + + custom_model = ( + DataScienceModel() + .with_compartment_id(target_compartment) + .with_project_id(target_project) + .with_model_file_description( + json_dict=service_model.model_file_description + ) + .with_display_name(service_model.display_name) + .with_description(service_model.description) + .with_freeform_tags(**combined_freeform_tags) + .with_defined_tags(**combined_defined_tags) + .with_custom_metadata_list(service_model.custom_metadata_list) + .with_defined_metadata_list(service_model.defined_metadata_list) + .with_provenance_metadata(service_model.provenance_metadata) + .create(model_by_reference=True, **kwargs) + ) + logger.info( + f"Aqua Model {custom_model.id} created with the service model {model_id}." + ) # Track unique models that were created in the user's compartment self.telemetry.record_event_async( diff --git a/ads/aqua/modeldeployment/constants.py b/ads/aqua/modeldeployment/constants.py index a37699301..7996b1418 100644 --- a/ads/aqua/modeldeployment/constants.py +++ b/ads/aqua/modeldeployment/constants.py @@ -11,3 +11,4 @@ DEFAULT_WAIT_TIME = 12000 DEFAULT_POLL_INTERVAL = 10 +DEFAULT_DEPLOYMENT_TYPE = "MODEL_STACK" diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index c000c9059..bc2e58e2f 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -61,7 +61,11 @@ ModelDeploymentConfigSummary, MultiModelDeploymentConfigLoader, ) -from ads.aqua.modeldeployment.constants import DEFAULT_POLL_INTERVAL, DEFAULT_WAIT_TIME +from ads.aqua.modeldeployment.constants import ( + DEFAULT_DEPLOYMENT_TYPE, + DEFAULT_POLL_INTERVAL, + DEFAULT_WAIT_TIME, +) from ads.aqua.modeldeployment.entities import ( AquaDeployment, AquaDeploymentDetail, @@ -162,6 +166,7 @@ def create( cmd_var (Optional[List[str]]): Command variables for the container runtime. freeform_tags (Optional[Dict]): Freeform tags for model deployment. defined_tags (Optional[Dict]): Defined tags for model deployment. + deployment_type (Optional[str]): The type of model deployment. Returns ------- @@ -206,13 +211,28 @@ def create( # Create an AquaModelApp instance once to perform the deployment creation. model_app = AquaModelApp() - if create_deployment_details.model_id: + if ( + create_deployment_details.model_id + or create_deployment_details.deployment_type == DEFAULT_DEPLOYMENT_TYPE + ): + model_id = create_deployment_details.model_id + if not model_id: + if len(create_deployment_details.models) != 1: + raise AquaValueError( + "Invalid 'models' provided. Only one base model is required for model stack deployment." + ) + model_id = create_deployment_details.models[0] + + service_model_id = ( + model_id if isinstance(model_id, str) else model_id.model_id + ) logger.debug( - f"Single model ({create_deployment_details.model_id}) provided. " + f"Single model ({service_model_id}) provided. " "Delegating to single model creation method." ) + aqua_model = model_app.create( - model_id=create_deployment_details.model_id, + model_id=model_id, compartment_id=compartment_id, project_id=project_id, freeform_tags=freeform_tags, @@ -677,7 +697,7 @@ def _build_model_group_config( def _create( self, - aqua_model: DataScienceModel, + aqua_model: Union[DataScienceModel, DataScienceModelGroup], create_deployment_details: CreateModelDeploymentDetails, container_config: Dict, ) -> AquaDeployment: @@ -711,7 +731,10 @@ def _create( tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK, UNKNOWN)}) # Set up info to get deployment config - config_source_id = create_deployment_details.model_id + config_source_id = ( + create_deployment_details.model_id + or create_deployment_details.models[0].model_id + ) model_name = aqua_model.display_name # set up env and cmd var diff --git a/ads/aqua/modeldeployment/entities.py b/ads/aqua/modeldeployment/entities.py index 0b65bc213..3a95dc37c 100644 --- a/ads/aqua/modeldeployment/entities.py +++ b/ads/aqua/modeldeployment/entities.py @@ -325,6 +325,9 @@ class CreateModelDeploymentDetails(BaseModel): defined_tags: Optional[Dict] = Field( None, description="Defined tags for model deployment." ) + deployment_type: Optional[str] = Field( + None, description="The type of model deployment." + ) @model_validator(mode="before") @classmethod diff --git a/ads/model/datascience_model_group.py b/ads/model/datascience_model_group.py index cc32ffa9c..247378231 100644 --- a/ads/model/datascience_model_group.py +++ b/ads/model/datascience_model_group.py @@ -22,6 +22,7 @@ ModelGroup, ModelGroupDetails, ModelGroupSummary, + StackedModelGroupDetails, UpdateModelGroupDetails, ) except ModuleNotFoundError as err: @@ -511,28 +512,38 @@ def create( def _build_model_group_details(self) -> dict: """Builds model group details dict for creating or updating oci model group.""" - model_group_details = HomogeneousModelGroupDetails( - custom_metadata_list=[ - CustomMetadata( - key=custom_metadata.key, - value=custom_metadata.value, - description=custom_metadata.description, - category=custom_metadata.category, - ) - for custom_metadata in self.custom_metadata_list._to_oci_metadata() - ] - ) + custom_metadata_list = [ + CustomMetadata( + key=custom_metadata.key, + value=custom_metadata.value, + description=custom_metadata.description, + category=custom_metadata.category, + ) + for custom_metadata in self.custom_metadata_list._to_oci_metadata() + ] + member_model_details = [ + MemberModelDetails(**member_model) for member_model in self.member_models + ] + + if self.base_model_id: + model_group_details = StackedModelGroupDetails( + custom_metadata_list=custom_metadata_list, + base_model_id=self.base_model_id, + ) + member_model_details.append(MemberModelDetails(model_id=self.base_model_id)) + else: + model_group_details = HomogeneousModelGroupDetails( + custom_metadata_list=custom_metadata_list + ) member_model_entries = MemberModelEntries( - member_model_details=[ - MemberModelDetails(**member_model) - for member_model in self.member_models - ] + member_model_details=member_model_details ) build_model_group_details = copy.deepcopy(self._spec) build_model_group_details.pop(self.CONST_CUSTOM_METADATA_LIST) build_model_group_details.pop(self.CONST_MEMBER_MODELS) + build_model_group_details.pop(self.CONST_BASE_MODEL_ID) build_model_group_details.update( { self.CONST_COMPARTMENT_ID: self.compartment_id or COMPARTMENT_OCID, @@ -581,6 +592,9 @@ def _update_from_oci_model( ) self.set_spec(self.CONST_CUSTOM_METADATA_LIST, model_custom_metadata) + if hasattr(model_group_details, "base_model_id"): + self.set_spec(self.CONST_BASE_MODEL_ID, model_group_details.base_model_id) + # only updates member_models when oci_model_group_instance is an instance of # oci.data_science.models.ModelGroup as oci.data_science.models.ModelGroupSummary # doesn't have member_model_entries property.