10
10
from oci .data_science .models import ModelDeployment
11
11
12
12
from ads .aqua .app import AquaApp , logger
13
- from ads .aqua .common .enums import Tags
13
+ from ads .aqua .common .enums import (
14
+ Tags ,
15
+ InferenceContainerParamType ,
16
+ InferenceContainerType ,
17
+ InferenceContainerTypeFamily ,
18
+ )
14
19
from ads .aqua .common .errors import AquaRuntimeError , AquaValueError
15
20
from ads .aqua .common .utils import (
16
21
get_container_config ,
38
43
AquaDeploymentDetail ,
39
44
ContainerSpec ,
40
45
)
41
- from ads .aqua .modeldeployment .constants import VLLMInferenceRestrictedParams
42
- from ads .aqua .modeldeployment .enums import (
43
- InferenceContainerParamType ,
44
- InferenceContainerType ,
45
- InferenceContainerTypeKey ,
46
+ from ads .aqua .modeldeployment .constants import (
47
+ VLLMInferenceRestrictedParams ,
48
+ TGIInferenceRestrictedParams ,
46
49
)
47
50
from ads .common .object_storage_details import ObjectStorageDetails
48
51
from ads .common .utils import get_log_links
@@ -106,6 +109,7 @@ def create(
106
109
server_port : int = None ,
107
110
health_check_port : int = None ,
108
111
env_var : Dict = None ,
112
+ container_family : str = None ,
109
113
) -> "AquaDeployment" :
110
114
"""
111
115
Creates a new Aqua deployment
@@ -144,6 +148,8 @@ def create(
144
148
The health check port for docker container image.
145
149
env_var : dict, optional
146
150
Environment variable for the deployment, by default None.
151
+ container_family: str
152
+ The image family of model deployment container runtime. Required for unverified Aqua models.
147
153
Returns
148
154
-------
149
155
AquaDeployment
@@ -227,9 +233,17 @@ def create(
227
233
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
228
234
).value
229
235
except ValueError :
230
- raise AquaValueError (
231
- f"{ AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME } key is not available in the custom metadata field for model { aqua_model .id } "
236
+ message = (
237
+ f"{ AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME } key is not available in the custom metadata field "
238
+ f"for model { aqua_model .id } ."
232
239
)
240
+ logger .debug (message )
241
+ if not container_family :
242
+ raise AquaValueError (
243
+ f"{ message } . For unverified Aqua models, container_family parameter should be "
244
+ f"set and value can be one of { ', ' .join (InferenceContainerTypeFamily .values ())} ."
245
+ )
246
+ container_type_key = container_family
233
247
try :
234
248
# Check if the container override flag is set. If set, then the user has chosen custom image
235
249
if aqua_model .custom_metadata_list .get (
@@ -275,13 +289,12 @@ def create(
275
289
.get (InferenceContainerParamType .PARAM_TYPE_VLLM , UNKNOWN )
276
290
)
277
291
278
- # todo: add support for tgi once parameters are added to configs. _find_restricted_params can take in
279
- # additional parameter container_type_key and should validate against TGIInferenceRestrictedParams set for
280
- # restricted params.
281
292
# validate user provided params
282
293
user_params = env_var .get ("PARAMS" , UNKNOWN )
283
294
if user_params :
284
- restricted_params = self ._find_restricted_params (params , user_params )
295
+ restricted_params = self ._find_restricted_params (
296
+ params , user_params , container_type_key
297
+ )
285
298
if restricted_params :
286
299
raise AquaValueError (
287
300
f"Parameters { restricted_params } are set by Aqua "
@@ -559,7 +572,7 @@ def get_deployment_default_params(
559
572
560
573
if container_type_key :
561
574
container_type_key = container_type_key .lower ()
562
- if container_type_key in InferenceContainerTypeKey .values ():
575
+ if container_type_key in InferenceContainerTypeFamily .values ():
563
576
deployment_config = self .get_deployment_config (model_id )
564
577
config_parameters = (
565
578
deployment_config .get ("configuration" , UNKNOWN_DICT )
@@ -587,7 +600,10 @@ def get_deployment_default_params(
587
600
return default_params
588
601
589
602
def validate_deployment_params (
590
- self , model_id : str , params : List [str ] = None
603
+ self ,
604
+ model_id : str ,
605
+ params : List [str ] = None ,
606
+ container_family : str = None ,
591
607
) -> Dict :
592
608
"""Validate if the deployment parameters passed by the user can be overridden. Parameter values are not
593
609
validated, only param keys are validated.
@@ -596,9 +612,10 @@ def validate_deployment_params(
596
612
----------
597
613
model_id: str
598
614
The OCID of the Aqua model.
599
-
600
615
params : List[str], optional
601
616
Params passed by the user.
617
+ container_family: str
618
+ The image family of model deployment container runtime. Required for unverified Aqua models.
602
619
603
620
Returns
604
621
-------
@@ -613,18 +630,28 @@ def validate_deployment_params(
613
630
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
614
631
).value
615
632
except ValueError :
616
- container_type_key = UNKNOWN
617
- logger . debug (
618
- f"{ AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME } key is not available in the custom metadata field for model { model_id } ."
633
+ message = (
634
+ f" { AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME } key is not available in the custom metadata field "
635
+ f"for model { model_id } ."
619
636
)
620
- if container_type_key :
621
- container_config = get_container_config ()
622
- container_spec = container_config .get (
623
- ContainerSpec .CONTAINER_SPEC , {}
624
- ).get (container_type_key , {})
625
- cli_params = container_spec .get (ContainerSpec .CLI_PARM , "" )
637
+ logger .debug (message )
626
638
627
- restricted_params = self ._find_restricted_params (cli_params , params )
639
+ if not container_family :
640
+ raise AquaValueError (
641
+ f"{ message } . For unverified Aqua models, container_family parameter should be "
642
+ f"set and value can be one of { ', ' .join (InferenceContainerTypeFamily .values ())} ."
643
+ )
644
+ container_type_key = container_family
645
+
646
+ container_config = get_container_config ()
647
+ container_spec = container_config .get (ContainerSpec .CONTAINER_SPEC , {}).get (
648
+ container_type_key , {}
649
+ )
650
+ cli_params = container_spec .get (ContainerSpec .CLI_PARM , "" )
651
+
652
+ restricted_params = self ._find_restricted_params (
653
+ cli_params , params , container_type_key
654
+ )
628
655
629
656
if restricted_params :
630
657
raise AquaValueError (
@@ -635,7 +662,9 @@ def validate_deployment_params(
635
662
636
663
@staticmethod
637
664
def _find_restricted_params (
638
- default_params : Union [str , List [str ]], user_params : Union [str , List [str ]]
665
+ default_params : Union [str , List [str ]],
666
+ user_params : Union [str , List [str ]],
667
+ container_family : str ,
639
668
) -> List [str ]:
640
669
"""Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
641
670
The default parameters coming from the container index json file cannot be overridden. In addition to this,
@@ -647,6 +676,8 @@ def _find_restricted_params(
647
676
Inference container parameter string with default values.
648
677
user_params:
649
678
Inference container parameter string with user provided values.
679
+ container_family: str
680
+ The image family of model deployment container runtime.
650
681
651
682
Returns
652
683
-------
@@ -659,7 +690,17 @@ def _find_restricted_params(
659
690
user_params_dict = get_params_dict (user_params )
660
691
661
692
for key , items in user_params_dict .items ():
662
- if key in default_params_dict or key in VLLMInferenceRestrictedParams :
693
+ if (
694
+ key in default_params_dict
695
+ or (
696
+ InferenceContainerType .CONTAINER_TYPE_VLLM in container_family
697
+ and key in VLLMInferenceRestrictedParams
698
+ )
699
+ or (
700
+ InferenceContainerType .CONTAINER_TYPE_TGI in container_family
701
+ and key in TGIInferenceRestrictedParams
702
+ )
703
+ ):
663
704
restricted_params .append (key .lstrip ("--" ))
664
705
665
706
return restricted_params
0 commit comments