46
46
from ads .config import (
47
47
AQUA_CONFIG_FOLDER ,
48
48
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME ,
49
- AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME ,
50
49
AQUA_MODEL_DEPLOYMENT_CONFIG ,
51
50
AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS ,
52
51
COMPARTMENT_OCID ,
@@ -87,26 +86,27 @@ class AquaDeploymentApp(AquaApp):
87
86
88
87
@telemetry (entry_point = "plugin=deployment&action=create" , name = "aqua" )
89
88
def create (
90
- self ,
91
- model_id : str ,
92
- instance_shape : str ,
93
- display_name : str ,
94
- instance_count : int = None ,
95
- log_group_id : str = None ,
96
- access_log_id : str = None ,
97
- predict_log_id : str = None ,
98
- compartment_id : str = None ,
99
- project_id : str = None ,
100
- description : str = None ,
101
- bandwidth_mbps : int = None ,
102
- web_concurrency : int = None ,
103
- server_port : int = None ,
104
- health_check_port : int = None ,
105
- env_var : Dict = None ,
106
- container_family : str = None ,
107
- memory_in_gbs : Optional [float ] = None ,
108
- ocpus : Optional [float ] = None ,
109
- model_file : Optional [str ] = None ,
89
+ self ,
90
+ model_id : str ,
91
+ instance_shape : str ,
92
+ display_name : str ,
93
+ instance_count : int = None ,
94
+ log_group_id : str = None ,
95
+ access_log_id : str = None ,
96
+ predict_log_id : str = None ,
97
+ compartment_id : str = None ,
98
+ project_id : str = None ,
99
+ description : str = None ,
100
+ bandwidth_mbps : int = None ,
101
+ web_concurrency : int = None ,
102
+ server_port : int = None ,
103
+ health_check_port : int = None ,
104
+ env_var : Dict = None ,
105
+ container_family : str = None ,
106
+ memory_in_gbs : Optional [float ] = None ,
107
+ ocpus : Optional [float ] = None ,
108
+ model_file : Optional [str ] = None ,
109
+ cmd_var : List [str ] = None ,
110
110
) -> "AquaDeployment" :
111
111
"""
112
112
Creates a new Aqua deployment
@@ -153,6 +153,8 @@ def create(
153
153
The ocpu count for the shape selected.
154
154
model_file: str
155
155
The file used for model deployment.
156
+ cmd_var: List[str]
157
+ The cmd of model deployment container runtime.
156
158
Returns
157
159
-------
158
160
AquaDeployment
@@ -231,8 +233,7 @@ def create(
231
233
env_var .update ({"FT_MODEL" : f"{ fine_tune_output_path } " })
232
234
233
235
container_type_key = self ._get_container_type_key (
234
- model = aqua_model ,
235
- container_family = container_family
236
+ model = aqua_model , container_family = container_family
236
237
)
237
238
238
239
# fetch image name from config
@@ -248,7 +249,11 @@ def create(
248
249
model_format = model_formats_str .split ("," )
249
250
250
251
# Figure out a better way to handle this in future release
251
- if ModelFormat .GGUF .value in model_format and container_type_key .lower () == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY :
252
+ if (
253
+ ModelFormat .GGUF .value in model_format
254
+ and container_type_key .lower ()
255
+ == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY
256
+ ):
252
257
if model_file is not None :
253
258
logger .info (
254
259
f"Overriding { model_file } as model_file for model { aqua_model .id } ."
@@ -299,8 +304,8 @@ def create(
299
304
if user_params :
300
305
# todo: remove this check in the future version, logic to be moved to container_index
301
306
if (
302
- container_type_key .lower ()
303
- == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY
307
+ container_type_key .lower ()
308
+ == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY
304
309
):
305
310
# AQUA_LLAMA_CPP_CONTAINER_FAMILY container uses uvicorn that required model/server params
306
311
# to be set as env vars
@@ -369,6 +374,8 @@ def create(
369
374
.with_overwrite_existing_artifact (True )
370
375
.with_remove_existing_artifact (True )
371
376
)
377
+ if cmd_var :
378
+ container_runtime .with_cmd (cmd_var )
372
379
373
380
# configure model deployment and deploy model on container runtime
374
381
deployment = (
@@ -422,9 +429,8 @@ def _get_container_type_key(model: DataScienceModel, container_family: str) -> s
422
429
f"for model { model .id } . For unverified Aqua models, { AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME } should be"
423
430
f"set and value can be one of { ', ' .join (InferenceContainerTypeFamily .values ())} ."
424
431
) from err
425
-
432
+
426
433
return container_type_key
427
-
428
434
429
435
@telemetry (entry_point = "plugin=deployment&action=list" , name = "aqua" )
430
436
def list (self , ** kwargs ) -> List ["AquaDeployment" ]:
@@ -453,8 +459,8 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
453
459
for model_deployment in model_deployments :
454
460
oci_aqua = (
455
461
(
456
- Tags .AQUA_TAG in model_deployment .freeform_tags
457
- or Tags .AQUA_TAG .lower () in model_deployment .freeform_tags
462
+ Tags .AQUA_TAG in model_deployment .freeform_tags
463
+ or Tags .AQUA_TAG .lower () in model_deployment .freeform_tags
458
464
)
459
465
if model_deployment .freeform_tags
460
466
else False
@@ -508,8 +514,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
508
514
509
515
oci_aqua = (
510
516
(
511
- Tags .AQUA_TAG in model_deployment .freeform_tags
512
- or Tags .AQUA_TAG .lower () in model_deployment .freeform_tags
517
+ Tags .AQUA_TAG in model_deployment .freeform_tags
518
+ or Tags .AQUA_TAG .lower () in model_deployment .freeform_tags
513
519
)
514
520
if model_deployment .freeform_tags
515
521
else False
@@ -526,8 +532,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
526
532
log_group_name = ""
527
533
528
534
logs = (
529
- model_deployment .category_log_details .access
530
- or model_deployment .category_log_details .predict
535
+ model_deployment .category_log_details .access
536
+ or model_deployment .category_log_details .predict
531
537
)
532
538
if logs :
533
539
log_id = logs .log_id
@@ -582,9 +588,9 @@ def get_deployment_config(self, model_id: str) -> Dict:
582
588
return config
583
589
584
590
def get_deployment_default_params (
585
- self ,
586
- model_id : str ,
587
- instance_shape : str ,
591
+ self ,
592
+ model_id : str ,
593
+ instance_shape : str ,
588
594
) -> List [str ]:
589
595
"""Gets the default params set in the deployment configs for the given model and instance shape.
590
596
@@ -616,8 +622,8 @@ def get_deployment_default_params(
616
622
)
617
623
618
624
if (
619
- container_type_key
620
- and container_type_key in InferenceContainerTypeFamily .values ()
625
+ container_type_key
626
+ and container_type_key in InferenceContainerTypeFamily .values ()
621
627
):
622
628
deployment_config = self .get_deployment_config (model_id )
623
629
config_params = (
@@ -640,10 +646,10 @@ def get_deployment_default_params(
640
646
return default_params
641
647
642
648
def validate_deployment_params (
643
- self ,
644
- model_id : str ,
645
- params : List [str ] = None ,
646
- container_family : str = None ,
649
+ self ,
650
+ model_id : str ,
651
+ params : List [str ] = None ,
652
+ container_family : str = None ,
647
653
) -> Dict :
648
654
"""Validate if the deployment parameters passed by the user can be overridden. Parameter values are not
649
655
validated, only param keys are validated.
@@ -666,8 +672,7 @@ def validate_deployment_params(
666
672
if params :
667
673
model = DataScienceModel .from_id (model_id )
668
674
container_type_key = self ._get_container_type_key (
669
- model = model ,
670
- container_family = container_family
675
+ model = model , container_family = container_family
671
676
)
672
677
673
678
container_config = get_container_config ()
@@ -689,9 +694,9 @@ def validate_deployment_params(
689
694
690
695
@staticmethod
691
696
def _find_restricted_params (
692
- default_params : Union [str , List [str ]],
693
- user_params : Union [str , List [str ]],
694
- container_family : str ,
697
+ default_params : Union [str , List [str ]],
698
+ user_params : Union [str , List [str ]],
699
+ container_family : str ,
695
700
) -> List [str ]:
696
701
"""Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
697
702
The default parameters coming from the container index json file cannot be overridden.
0 commit comments