Skip to content

Commit 0f0e302

Browse files
committed
Merge branch 'main' of https://github.com/oracle/accelerated-data-science into ODSC-64017/add_private_endpoint_for_md
2 parents b918f98 + b44e1a7 commit 0f0e302

File tree

2 files changed

+50
-50
lines changed

2 files changed

+50
-50
lines changed

ads/aqua/common/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1071,7 +1071,6 @@ def list_hf_models(query: str) -> List[str]:
10711071
try:
10721072
models = HfApi().list_models(
10731073
model_name=query,
1074-
task="text-generation",
10751074
sort="downloads",
10761075
direction=-1,
10771076
limit=20,

ads/aqua/modeldeployment/deployment.py

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
from ads.config import (
4747
AQUA_CONFIG_FOLDER,
4848
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
49-
AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME,
5049
AQUA_MODEL_DEPLOYMENT_CONFIG,
5150
AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS,
5251
COMPARTMENT_OCID,
@@ -87,27 +86,27 @@ class AquaDeploymentApp(AquaApp):
8786

8887
@telemetry(entry_point="plugin=deployment&action=create", name="aqua")
8988
def create(
90-
self,
91-
model_id: str,
92-
instance_shape: str,
93-
display_name: str,
94-
instance_count: int = None,
95-
log_group_id: str = None,
96-
access_log_id: str = None,
97-
predict_log_id: str = None,
98-
compartment_id: str = None,
99-
project_id: str = None,
100-
description: str = None,
101-
bandwidth_mbps: int = None,
102-
web_concurrency: int = None,
103-
server_port: int = None,
104-
health_check_port: int = None,
105-
env_var: Dict = None,
106-
container_family: str = None,
107-
memory_in_gbs: Optional[float] = None,
108-
ocpus: Optional[float] = None,
109-
model_file: Optional[str] = None,
110-
private_endpoint_id: Optional[str] = None,
89+
self,
90+
model_id: str,
91+
instance_shape: str,
92+
display_name: str,
93+
instance_count: int = None,
94+
log_group_id: str = None,
95+
access_log_id: str = None,
96+
predict_log_id: str = None,
97+
compartment_id: str = None,
98+
project_id: str = None,
99+
description: str = None,
100+
bandwidth_mbps: int = None,
101+
web_concurrency: int = None,
102+
server_port: int = None,
103+
health_check_port: int = None,
104+
env_var: Dict = None,
105+
container_family: str = None,
106+
memory_in_gbs: Optional[float] = None,
107+
ocpus: Optional[float] = None,
108+
model_file: Optional[str] = None,
109+
private_endpoint_id: Optional[str] = None,
111110
) -> "AquaDeployment":
112111
"""
113112
Creates a new Aqua deployment
@@ -179,6 +178,7 @@ def create(
179178
tags[tag] = aqua_model.freeform_tags[tag]
180179

181180
tags.update({Tags.AQUA_MODEL_NAME_TAG: aqua_model.display_name})
181+
tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK, None)})
182182

183183
# Set up info to get deployment config
184184
config_source_id = model_id
@@ -235,8 +235,7 @@ def create(
235235
env_var.update({"FT_MODEL": f"{fine_tune_output_path}"})
236236

237237
container_type_key = self._get_container_type_key(
238-
model=aqua_model,
239-
container_family=container_family
238+
model=aqua_model, container_family=container_family
240239
)
241240

242241
# fetch image name from config
@@ -252,7 +251,11 @@ def create(
252251
model_format = model_formats_str.split(",")
253252

254253
# Figure out a better way to handle this in future release
255-
if ModelFormat.GGUF.value in model_format and container_type_key.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY:
254+
if (
255+
ModelFormat.GGUF.value in model_format
256+
and container_type_key.lower()
257+
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
258+
):
256259
if model_file is not None:
257260
logger.info(
258261
f"Overriding {model_file} as model_file for model {aqua_model.id}."
@@ -303,8 +306,8 @@ def create(
303306
if user_params:
304307
# todo: remove this check in the future version, logic to be moved to container_index
305308
if (
306-
container_type_key.lower()
307-
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
309+
container_type_key.lower()
310+
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
308311
):
309312
# AQUA_LLAMA_CPP_CONTAINER_FAMILY container uses uvicorn that required model/server params
310313
# to be set as env vars
@@ -427,9 +430,8 @@ def _get_container_type_key(model: DataScienceModel, container_family: str) -> s
427430
f"for model {model.id}. For unverified Aqua models, {AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} should be"
428431
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
429432
) from err
430-
433+
431434
return container_type_key
432-
433435

434436
@telemetry(entry_point="plugin=deployment&action=list", name="aqua")
435437
def list(self, **kwargs) -> List["AquaDeployment"]:
@@ -458,8 +460,8 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
458460
for model_deployment in model_deployments:
459461
oci_aqua = (
460462
(
461-
Tags.AQUA_TAG in model_deployment.freeform_tags
462-
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
463+
Tags.AQUA_TAG in model_deployment.freeform_tags
464+
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
463465
)
464466
if model_deployment.freeform_tags
465467
else False
@@ -513,8 +515,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
513515

514516
oci_aqua = (
515517
(
516-
Tags.AQUA_TAG in model_deployment.freeform_tags
517-
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
518+
Tags.AQUA_TAG in model_deployment.freeform_tags
519+
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
518520
)
519521
if model_deployment.freeform_tags
520522
else False
@@ -531,8 +533,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
531533
log_group_name = ""
532534

533535
logs = (
534-
model_deployment.category_log_details.access
535-
or model_deployment.category_log_details.predict
536+
model_deployment.category_log_details.access
537+
or model_deployment.category_log_details.predict
536538
)
537539
if logs:
538540
log_id = logs.log_id
@@ -587,9 +589,9 @@ def get_deployment_config(self, model_id: str) -> Dict:
587589
return config
588590

589591
def get_deployment_default_params(
590-
self,
591-
model_id: str,
592-
instance_shape: str,
592+
self,
593+
model_id: str,
594+
instance_shape: str,
593595
) -> List[str]:
594596
"""Gets the default params set in the deployment configs for the given model and instance shape.
595597
@@ -621,8 +623,8 @@ def get_deployment_default_params(
621623
)
622624

623625
if (
624-
container_type_key
625-
and container_type_key in InferenceContainerTypeFamily.values()
626+
container_type_key
627+
and container_type_key in InferenceContainerTypeFamily.values()
626628
):
627629
deployment_config = self.get_deployment_config(model_id)
628630
config_params = (
@@ -645,10 +647,10 @@ def get_deployment_default_params(
645647
return default_params
646648

647649
def validate_deployment_params(
648-
self,
649-
model_id: str,
650-
params: List[str] = None,
651-
container_family: str = None,
650+
self,
651+
model_id: str,
652+
params: List[str] = None,
653+
container_family: str = None,
652654
) -> Dict:
653655
"""Validate if the deployment parameters passed by the user can be overridden. Parameter values are not
654656
validated, only param keys are validated.
@@ -671,8 +673,7 @@ def validate_deployment_params(
671673
if params:
672674
model = DataScienceModel.from_id(model_id)
673675
container_type_key = self._get_container_type_key(
674-
model=model,
675-
container_family=container_family
676+
model=model, container_family=container_family
676677
)
677678

678679
container_config = get_container_config()
@@ -694,9 +695,9 @@ def validate_deployment_params(
694695

695696
@staticmethod
696697
def _find_restricted_params(
697-
default_params: Union[str, List[str]],
698-
user_params: Union[str, List[str]],
699-
container_family: str,
698+
default_params: Union[str, List[str]],
699+
user_params: Union[str, List[str]],
700+
container_family: str,
700701
) -> List[str]:
701702
"""Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
702703
The default parameters coming from the container index json file cannot be overridden.

0 commit comments

Comments
 (0)