Skip to content

Commit 9599dab

Browse files
remove redundant code
1 parent 54f9435 commit 9599dab

File tree

2 files changed

+53
-33
lines changed

2 files changed

+53
-33
lines changed

ads/aqua/common/utils.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@
4242
UNKNOWN_JSON_STR,
4343
)
4444
from ads.aqua.data import AquaResourceIdentifier
45+
from ads.aqua.modeldeployment.constants import (
46+
TGIInferenceRestrictedParams,
47+
VLLMInferenceRestrictedParams,
48+
)
4549
from ads.common.auth import default_signer
4650
from ads.common.decorator.threaded import threaded
4751
from ads.common.extended_enum import ExtendedEnumMeta
@@ -862,7 +866,7 @@ def copy_model_config(artifact_path: str, os_path: str, auth: dict = None):
862866
logger.debug(f"Failed to copy config folder from {artifact_path} to {os_path}.")
863867

864868

865-
def get_container_params_type(container_type_name: str):
869+
def get_container_params_type(container_type_name: str) -> str:
866870
"""The utility function accepts the deployment container type name and returns the corresponding params name.
867871
Parameters
868872
----------
@@ -875,9 +879,31 @@ def get_container_params_type(container_type_name: str):
875879
876880
"""
877881
# check substring instead of direct match in case container_type_name changes in the future
878-
if InferenceContainerType.CONTAINER_TYPE_VLLM in container_type_name:
882+
if InferenceContainerType.CONTAINER_TYPE_VLLM in container_type_name.lower():
879883
return InferenceContainerParamType.PARAM_TYPE_VLLM
880-
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name:
884+
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
881885
return InferenceContainerParamType.PARAM_TYPE_TGI
882886
else:
883887
return UNKNOWN
888+
889+
890+
def get_restricted_params_by_container(container_type_name: str) -> set:
891+
"""The utility function accepts the deployment container type name and returns a set of restricted params
892+
for that container.
893+
Parameters
894+
----------
895+
container_type_name: str
896+
type of deployment container, like odsc-vllm-serving or odsc-tgi-serving.
897+
898+
Returns
899+
-------
900+
InferenceContainerParamType value
901+
902+
"""
903+
# check substring instead of direct match in case container_type_name changes in the future
904+
if InferenceContainerType.CONTAINER_TYPE_VLLM in container_type_name.lower():
905+
return VLLMInferenceRestrictedParams
906+
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
907+
return TGIInferenceRestrictedParams
908+
else:
909+
return set()

ads/aqua/modeldeployment/deployment.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from ads.aqua.app import AquaApp, logger
99
from ads.aqua.common.enums import (
10-
InferenceContainerType,
1110
InferenceContainerTypeFamily,
1211
Tags,
1312
)
@@ -22,6 +21,7 @@
2221
get_params_dict,
2322
get_params_list,
2423
get_resource_name,
24+
get_restricted_params_by_container,
2525
load_config,
2626
)
2727
from ads.aqua.constants import (
@@ -34,10 +34,6 @@
3434
from ads.aqua.data import AquaResourceIdentifier
3535
from ads.aqua.finetuning.finetuning import FineTuneCustomMetadata
3636
from ads.aqua.model import AquaModelApp
37-
from ads.aqua.modeldeployment.constants import (
38-
TGIInferenceRestrictedParams,
39-
VLLMInferenceRestrictedParams,
40-
)
4137
from ads.aqua.modeldeployment.entities import (
4238
AquaDeployment,
4339
AquaDeploymentDetail,
@@ -567,19 +563,27 @@ def get_deployment_default_params(
567563
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field for model {model_id}."
568564
)
569565

570-
if container_type_key:
571-
container_type_key = container_type_key.lower()
572-
if container_type_key in InferenceContainerTypeFamily.values():
573-
deployment_config = self.get_deployment_config(model_id)
574-
params = (
575-
deployment_config.get("configuration", UNKNOWN_DICT)
576-
.get(instance_shape, UNKNOWN_DICT)
577-
.get("parameters", UNKNOWN_DICT)
578-
.get(get_container_params_type(container_type_key))
566+
if (
567+
container_type_key
568+
and container_type_key in InferenceContainerTypeFamily.values()
569+
):
570+
deployment_config = self.get_deployment_config(model_id)
571+
config_params = (
572+
deployment_config.get("configuration", UNKNOWN_DICT)
573+
.get(instance_shape, UNKNOWN_DICT)
574+
.get("parameters", UNKNOWN_DICT)
575+
.get(get_container_params_type(container_type_key), UNKNOWN)
576+
)
577+
if config_params:
578+
params_list = get_params_list(config_params)
579+
restricted_params_set = get_restricted_params_by_container(
580+
container_type_key
579581
)
580-
if params:
581-
# account for param that can have --arg but no values, e.g. --trust-remote-code
582-
default_params.extend(get_params_list(params))
582+
583+
# remove restricted params from the list as user cannot override them during deployment
584+
for params in params_list:
585+
if params.split()[0] not in restricted_params_set:
586+
default_params.append(params)
583587

584588
return default_params
585589

@@ -651,8 +655,7 @@ def _find_restricted_params(
651655
container_family: str,
652656
) -> List[str]:
653657
"""Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
654-
The default parameters coming from the container index json file cannot be overridden. In addition to this,
655-
a set of parameters maintained in
658+
The default parameters coming from the container index json file cannot be overridden.
656659
657660
Parameters
658661
----------
@@ -673,18 +676,9 @@ def _find_restricted_params(
673676
default_params_dict = get_params_dict(default_params)
674677
user_params_dict = get_params_dict(user_params)
675678

679+
restricted_params_set = get_restricted_params_by_container(container_family)
676680
for key, _items in user_params_dict.items():
677-
if (
678-
key in default_params_dict
679-
or (
680-
InferenceContainerType.CONTAINER_TYPE_VLLM in container_family
681-
and key in VLLMInferenceRestrictedParams
682-
)
683-
or (
684-
InferenceContainerType.CONTAINER_TYPE_TGI in container_family
685-
and key in TGIInferenceRestrictedParams
686-
)
687-
):
681+
if key in default_params_dict or key in restricted_params_set:
688682
restricted_params.append(key.lstrip("-"))
689683

690684
return restricted_params

0 commit comments

Comments
 (0)