7
7
8
8
from ads .aqua .app import AquaApp , logger
9
9
from ads .aqua .common .enums import (
10
- InferenceContainerType ,
11
10
InferenceContainerTypeFamily ,
12
11
Tags ,
13
12
)
22
21
get_params_dict ,
23
22
get_params_list ,
24
23
get_resource_name ,
24
+ get_restricted_params_by_container ,
25
25
load_config ,
26
26
)
27
27
from ads .aqua .constants import (
34
34
from ads .aqua .data import AquaResourceIdentifier
35
35
from ads .aqua .finetuning .finetuning import FineTuneCustomMetadata
36
36
from ads .aqua .model import AquaModelApp
37
- from ads .aqua .modeldeployment .constants import (
38
- TGIInferenceRestrictedParams ,
39
- VLLMInferenceRestrictedParams ,
40
- )
41
37
from ads .aqua .modeldeployment .entities import (
42
38
AquaDeployment ,
43
39
AquaDeploymentDetail ,
@@ -567,19 +563,27 @@ def get_deployment_default_params(
567
563
f"{ AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME } key is not available in the custom metadata field for model { model_id } ."
568
564
)
569
565
570
- if container_type_key :
571
- container_type_key = container_type_key .lower ()
572
- if container_type_key in InferenceContainerTypeFamily .values ():
573
- deployment_config = self .get_deployment_config (model_id )
574
- params = (
575
- deployment_config .get ("configuration" , UNKNOWN_DICT )
576
- .get (instance_shape , UNKNOWN_DICT )
577
- .get ("parameters" , UNKNOWN_DICT )
578
- .get (get_container_params_type (container_type_key ))
566
+ if (
567
+ container_type_key
568
+ and container_type_key in InferenceContainerTypeFamily .values ()
569
+ ):
570
+ deployment_config = self .get_deployment_config (model_id )
571
+ config_params = (
572
+ deployment_config .get ("configuration" , UNKNOWN_DICT )
573
+ .get (instance_shape , UNKNOWN_DICT )
574
+ .get ("parameters" , UNKNOWN_DICT )
575
+ .get (get_container_params_type (container_type_key ), UNKNOWN )
576
+ )
577
+ if config_params :
578
+ params_list = get_params_list (config_params )
579
+ restricted_params_set = get_restricted_params_by_container (
580
+ container_type_key
579
581
)
580
- if params :
581
- # account for param that can have --arg but no values, e.g. --trust-remote-code
582
- default_params .extend (get_params_list (params ))
582
+
583
+ # remove restricted params from the list as user cannot override them during deployment
584
+ for params in params_list :
585
+ if params .split ()[0 ] not in restricted_params_set :
586
+ default_params .append (params )
583
587
584
588
return default_params
585
589
@@ -651,8 +655,7 @@ def _find_restricted_params(
651
655
container_family : str ,
652
656
) -> List [str ]:
653
657
"""Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
654
- The default parameters coming from the container index json file cannot be overridden. In addition to this,
655
- a set of parameters maintained in
658
+ The default parameters coming from the container index json file cannot be overridden.
656
659
657
660
Parameters
658
661
----------
@@ -673,18 +676,9 @@ def _find_restricted_params(
673
676
default_params_dict = get_params_dict (default_params )
674
677
user_params_dict = get_params_dict (user_params )
675
678
679
+ restricted_params_set = get_restricted_params_by_container (container_family )
676
680
for key , _items in user_params_dict .items ():
677
- if (
678
- key in default_params_dict
679
- or (
680
- InferenceContainerType .CONTAINER_TYPE_VLLM in container_family
681
- and key in VLLMInferenceRestrictedParams
682
- )
683
- or (
684
- InferenceContainerType .CONTAINER_TYPE_TGI in container_family
685
- and key in TGIInferenceRestrictedParams
686
- )
687
- ):
681
+ if key in default_params_dict or key in restricted_params_set :
688
682
restricted_params .append (key .lstrip ("-" ))
689
683
690
684
return restricted_params
0 commit comments