1
1
#!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
2
# Copyright (c) 2024 Oracle and/or its affiliates.
4
3
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
4
6
- import json
7
5
import logging
8
6
from typing import Dict , List , Union
9
7
10
- from oci .data_science .models import ModelDeployment
11
-
12
8
from ads .aqua .app import AquaApp , logger
13
9
from ads .aqua .common .enums import (
14
- Tags ,
15
- InferenceContainerParamType ,
16
10
InferenceContainerType ,
17
11
InferenceContainerTypeFamily ,
12
+ Tags ,
18
13
)
19
14
from ads .aqua .common .errors import AquaRuntimeError , AquaValueError
20
15
from ads .aqua .common .utils import (
16
+ get_combined_params ,
21
17
get_container_config ,
22
18
get_container_image ,
19
+ get_container_params_type ,
23
20
get_model_by_reference_paths ,
24
21
get_ocid_substring ,
25
- get_combined_params ,
26
22
get_params_dict ,
27
23
get_params_list ,
28
24
get_resource_name ,
38
34
from ads .aqua .data import AquaResourceIdentifier
39
35
from ads .aqua .finetuning .finetuning import FineTuneCustomMetadata
40
36
from ads .aqua .model import AquaModelApp
37
+ from ads .aqua .modeldeployment .constants import (
38
+ TGIInferenceRestrictedParams ,
39
+ VLLMInferenceRestrictedParams ,
40
+ )
41
41
from ads .aqua .modeldeployment .entities import (
42
42
AquaDeployment ,
43
43
AquaDeploymentDetail ,
44
44
ContainerSpec ,
45
45
)
46
- from ads .aqua .modeldeployment .constants import (
47
- VLLMInferenceRestrictedParams ,
48
- TGIInferenceRestrictedParams ,
49
- )
50
46
from ads .common .object_storage_details import ObjectStorageDetails
51
47
from ads .common .utils import get_log_links
52
48
from ads .config import (
@@ -187,24 +183,24 @@ def create(
187
183
model_name = aqua_model .custom_metadata_list .get (
188
184
FineTuneCustomMetadata .FINE_TUNE_SOURCE_NAME
189
185
).value
190
- except :
186
+ except ValueError as err :
191
187
raise AquaValueError (
192
188
f"Either { FineTuneCustomMetadata .FINE_TUNE_SOURCE } or { FineTuneCustomMetadata .FINE_TUNE_SOURCE_NAME } is missing "
193
189
f"from custom metadata for the model { config_source_id } "
194
- )
190
+ ) from err
195
191
196
192
# set up env vars
197
193
if not env_var :
198
- env_var = dict ()
194
+ env_var = {}
199
195
200
196
try :
201
197
model_path_prefix = aqua_model .custom_metadata_list .get (
202
198
MODEL_BY_REFERENCE_OSS_PATH_KEY
203
199
).value .rstrip ("/" )
204
- except ValueError :
200
+ except ValueError as err :
205
201
raise AquaValueError (
206
202
f"{ MODEL_BY_REFERENCE_OSS_PATH_KEY } key is not available in the custom metadata field."
207
- )
203
+ ) from err
208
204
209
205
if ObjectStorageDetails .is_oci_path (model_path_prefix ):
210
206
os_path = ObjectStorageDetails .from_path (model_path_prefix )
@@ -219,7 +215,7 @@ def create(
219
215
220
216
if not fine_tune_output_path :
221
217
raise AquaValueError (
222
- f "Fine tuned output path is not available in the model artifact."
218
+ "Fine tuned output path is not available in the model artifact."
223
219
)
224
220
225
221
os_path = ObjectStorageDetails .from_path (fine_tune_output_path )
@@ -232,7 +228,7 @@ def create(
232
228
container_type_key = aqua_model .custom_metadata_list .get (
233
229
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
234
230
).value
235
- except ValueError :
231
+ except ValueError as err :
236
232
message = (
237
233
f"{ AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME } key is not available in the custom metadata field "
238
234
f"for model { aqua_model .id } ."
@@ -242,7 +238,7 @@ def create(
242
238
raise AquaValueError (
243
239
f"{ message } . For unverified Aqua models, container_family parameter should be "
244
240
f"set and value can be one of { ', ' .join (InferenceContainerTypeFamily .values ())} ."
245
- )
241
+ ) from err
246
242
container_type_key = container_family
247
243
try :
248
244
# Check if the container override flag is set. If set, then the user has chosen custom image
@@ -282,11 +278,12 @@ def create(
282
278
) # Give precendece to the input parameter
283
279
284
280
deployment_config = self .get_deployment_config (config_source_id )
285
- vllm_params = (
281
+
282
+ config_params = (
286
283
deployment_config .get ("configuration" , UNKNOWN_DICT )
287
284
.get (instance_shape , UNKNOWN_DICT )
288
285
.get ("parameters" , UNKNOWN_DICT )
289
- .get (InferenceContainerParamType . PARAM_TYPE_VLLM , UNKNOWN )
286
+ .get (get_container_params_type ( container_type_key ) , UNKNOWN )
290
287
)
291
288
292
289
# validate user provided params
@@ -301,7 +298,7 @@ def create(
301
298
f"and cannot be overridden or are invalid."
302
299
)
303
300
304
- deployment_params = get_combined_params (vllm_params , user_params )
301
+ deployment_params = get_combined_params (config_params , user_params )
305
302
306
303
if deployment_params :
307
304
params = f"{ params } { deployment_params } "
@@ -429,7 +426,7 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
429
426
# tracks unique deployments that were listed in the user compartment
430
427
# we arbitrarily choose last 8 characters of OCID to identify MD in telemetry
431
428
self .telemetry .record_event_async (
432
- category = f "aqua/deployment" ,
429
+ category = "aqua/deployment" ,
433
430
action = "list" ,
434
431
detail = get_ocid_substring (deployment_id , key_len = 8 ),
435
432
value = state ,
@@ -574,25 +571,12 @@ def get_deployment_default_params(
574
571
container_type_key = container_type_key .lower ()
575
572
if container_type_key in InferenceContainerTypeFamily .values ():
576
573
deployment_config = self .get_deployment_config (model_id )
577
- config_parameters = (
574
+ params = (
578
575
deployment_config .get ("configuration" , UNKNOWN_DICT )
579
576
.get (instance_shape , UNKNOWN_DICT )
580
577
.get ("parameters" , UNKNOWN_DICT )
578
+ .get (get_container_params_type (container_type_key ))
581
579
)
582
- if InferenceContainerType .CONTAINER_TYPE_VLLM in container_type_key :
583
- params = config_parameters .get (
584
- InferenceContainerParamType .PARAM_TYPE_VLLM , UNKNOWN
585
- )
586
- elif InferenceContainerType .CONTAINER_TYPE_TGI in container_type_key :
587
- params = config_parameters .get (
588
- InferenceContainerParamType .PARAM_TYPE_TGI , UNKNOWN
589
- )
590
- else :
591
- params = UNKNOWN
592
- logger .debug (
593
- f"Default inference parameters are not available for the model { model_id } and "
594
- f"instance { instance_shape } ."
595
- )
596
580
if params :
597
581
# account for param that can have --arg but no values, e.g. --trust-remote-code
598
582
default_params .extend (get_params_list (params ))
@@ -629,7 +613,7 @@ def validate_deployment_params(
629
613
container_type_key = model .custom_metadata_list .get (
630
614
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
631
615
).value
632
- except ValueError :
616
+ except ValueError as err :
633
617
message = (
634
618
f"{ AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME } key is not available in the custom metadata field "
635
619
f"for model { model_id } ."
@@ -640,7 +624,7 @@ def validate_deployment_params(
640
624
raise AquaValueError (
641
625
f"{ message } . For unverified Aqua models, container_family parameter should be "
642
626
f"set and value can be one of { ', ' .join (InferenceContainerTypeFamily .values ())} ."
643
- )
627
+ ) from err
644
628
container_type_key = container_family
645
629
646
630
container_config = get_container_config ()
@@ -658,7 +642,7 @@ def validate_deployment_params(
658
642
f"Parameters { restricted_params } are set by Aqua "
659
643
f"and cannot be overridden or are invalid."
660
644
)
661
- return dict ( valid = True )
645
+ return { " valid" : True }
662
646
663
647
@staticmethod
664
648
def _find_restricted_params (
@@ -689,7 +673,7 @@ def _find_restricted_params(
689
673
default_params_dict = get_params_dict (default_params )
690
674
user_params_dict = get_params_dict (user_params )
691
675
692
- for key , items in user_params_dict .items ():
676
+ for key , _items in user_params_dict .items ():
693
677
if (
694
678
key in default_params_dict
695
679
or (
@@ -701,6 +685,6 @@ def _find_restricted_params(
701
685
and key in TGIInferenceRestrictedParams
702
686
)
703
687
):
704
- restricted_params .append (key .lstrip ("-- " ))
688
+ restricted_params .append (key .lstrip ("-" ))
705
689
706
690
return restricted_params
0 commit comments