Skip to content

Commit 57d930f

Browse files
Adding model task in MD tags
1 parent c034c5f commit 57d930f

File tree

1 file changed

+49
-48
lines changed

1 file changed

+49
-48
lines changed

ads/aqua/modeldeployment/deployment.py

Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
from ads.config import (
4747
AQUA_CONFIG_FOLDER,
4848
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
49-
AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME,
5049
AQUA_MODEL_DEPLOYMENT_CONFIG,
5150
AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS,
5251
COMPARTMENT_OCID,
@@ -87,26 +86,26 @@ class AquaDeploymentApp(AquaApp):
8786

8887
@telemetry(entry_point="plugin=deployment&action=create", name="aqua")
8988
def create(
90-
self,
91-
model_id: str,
92-
instance_shape: str,
93-
display_name: str,
94-
instance_count: int = None,
95-
log_group_id: str = None,
96-
access_log_id: str = None,
97-
predict_log_id: str = None,
98-
compartment_id: str = None,
99-
project_id: str = None,
100-
description: str = None,
101-
bandwidth_mbps: int = None,
102-
web_concurrency: int = None,
103-
server_port: int = None,
104-
health_check_port: int = None,
105-
env_var: Dict = None,
106-
container_family: str = None,
107-
memory_in_gbs: Optional[float] = None,
108-
ocpus: Optional[float] = None,
109-
model_file: Optional[str] = None,
89+
self,
90+
model_id: str,
91+
instance_shape: str,
92+
display_name: str,
93+
instance_count: int = None,
94+
log_group_id: str = None,
95+
access_log_id: str = None,
96+
predict_log_id: str = None,
97+
compartment_id: str = None,
98+
project_id: str = None,
99+
description: str = None,
100+
bandwidth_mbps: int = None,
101+
web_concurrency: int = None,
102+
server_port: int = None,
103+
health_check_port: int = None,
104+
env_var: Dict = None,
105+
container_family: str = None,
106+
memory_in_gbs: Optional[float] = None,
107+
ocpus: Optional[float] = None,
108+
model_file: Optional[str] = None,
110109
) -> "AquaDeployment":
111110
"""
112111
Creates a new Aqua deployment
@@ -175,6 +174,7 @@ def create(
175174
tags[tag] = aqua_model.freeform_tags[tag]
176175

177176
tags.update({Tags.AQUA_MODEL_NAME_TAG: aqua_model.display_name})
177+
tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK, None)})
178178

179179
# Set up info to get deployment config
180180
config_source_id = model_id
@@ -231,8 +231,7 @@ def create(
231231
env_var.update({"FT_MODEL": f"{fine_tune_output_path}"})
232232

233233
container_type_key = self._get_container_type_key(
234-
model=aqua_model,
235-
container_family=container_family
234+
model=aqua_model, container_family=container_family
236235
)
237236

238237
# fetch image name from config
@@ -248,7 +247,11 @@ def create(
248247
model_format = model_formats_str.split(",")
249248

250249
# TODO: figure out a better way to handle this in a future release
251-
if ModelFormat.GGUF.value in model_format and container_type_key.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY:
250+
if (
251+
ModelFormat.GGUF.value in model_format
252+
and container_type_key.lower()
253+
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
254+
):
252255
if model_file is not None:
253256
logger.info(
254257
f"Overriding {model_file} as model_file for model {aqua_model.id}."
@@ -299,8 +302,8 @@ def create(
299302
if user_params:
300303
# TODO: remove this check in a future version; the logic is to be moved to container_index
301304
if (
302-
container_type_key.lower()
303-
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
305+
container_type_key.lower()
306+
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
304307
):
305308
# AQUA_LLAMA_CPP_CONTAINER_FAMILY container uses uvicorn, which requires model/server params
306309
# to be set as env vars
@@ -422,9 +425,8 @@ def _get_container_type_key(model: DataScienceModel, container_family: str) -> s
422425
f"for model {model.id}. For unverified Aqua models, {AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} should be"
423426
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
424427
) from err
425-
428+
426429
return container_type_key
427-
428430

429431
@telemetry(entry_point="plugin=deployment&action=list", name="aqua")
430432
def list(self, **kwargs) -> List["AquaDeployment"]:
@@ -453,8 +455,8 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
453455
for model_deployment in model_deployments:
454456
oci_aqua = (
455457
(
456-
Tags.AQUA_TAG in model_deployment.freeform_tags
457-
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
458+
Tags.AQUA_TAG in model_deployment.freeform_tags
459+
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
458460
)
459461
if model_deployment.freeform_tags
460462
else False
@@ -508,8 +510,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
508510

509511
oci_aqua = (
510512
(
511-
Tags.AQUA_TAG in model_deployment.freeform_tags
512-
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
513+
Tags.AQUA_TAG in model_deployment.freeform_tags
514+
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
513515
)
514516
if model_deployment.freeform_tags
515517
else False
@@ -526,8 +528,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
526528
log_group_name = ""
527529

528530
logs = (
529-
model_deployment.category_log_details.access
530-
or model_deployment.category_log_details.predict
531+
model_deployment.category_log_details.access
532+
or model_deployment.category_log_details.predict
531533
)
532534
if logs:
533535
log_id = logs.log_id
@@ -582,9 +584,9 @@ def get_deployment_config(self, model_id: str) -> Dict:
582584
return config
583585

584586
def get_deployment_default_params(
585-
self,
586-
model_id: str,
587-
instance_shape: str,
587+
self,
588+
model_id: str,
589+
instance_shape: str,
588590
) -> List[str]:
589591
"""Gets the default params set in the deployment configs for the given model and instance shape.
590592
@@ -616,8 +618,8 @@ def get_deployment_default_params(
616618
)
617619

618620
if (
619-
container_type_key
620-
and container_type_key in InferenceContainerTypeFamily.values()
621+
container_type_key
622+
and container_type_key in InferenceContainerTypeFamily.values()
621623
):
622624
deployment_config = self.get_deployment_config(model_id)
623625
config_params = (
@@ -640,10 +642,10 @@ def get_deployment_default_params(
640642
return default_params
641643

642644
def validate_deployment_params(
643-
self,
644-
model_id: str,
645-
params: List[str] = None,
646-
container_family: str = None,
645+
self,
646+
model_id: str,
647+
params: List[str] = None,
648+
container_family: str = None,
647649
) -> Dict:
648650
"""Validate if the deployment parameters passed by the user can be overridden. Parameter values are not
649651
validated, only param keys are validated.
@@ -666,8 +668,7 @@ def validate_deployment_params(
666668
if params:
667669
model = DataScienceModel.from_id(model_id)
668670
container_type_key = self._get_container_type_key(
669-
model=model,
670-
container_family=container_family
671+
model=model, container_family=container_family
671672
)
672673

673674
container_config = get_container_config()
@@ -689,9 +690,9 @@ def validate_deployment_params(
689690

690691
@staticmethod
691692
def _find_restricted_params(
692-
default_params: Union[str, List[str]],
693-
user_params: Union[str, List[str]],
694-
container_family: str,
693+
default_params: Union[str, List[str]],
694+
user_params: Union[str, List[str]],
695+
container_family: str,
695696
) -> List[str]:
696697
"""Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
697698
The default parameters coming from the container index json file cannot be overridden.

0 commit comments

Comments
 (0)