Skip to content

Commit 54f9435

Browse files
bug fix for override params issue and ruff updates
1 parent e1b222b commit 54f9435

File tree

4 files changed

+81
-75
lines changed

4 files changed

+81
-75
lines changed

ads/aqua/common/utils.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
"""AQUA utils and constants."""
5+
66
import asyncio
77
import base64
88
import json
@@ -19,13 +19,28 @@
1919
import oci
2020
from oci.data_science.models import JobRun, Model
2121

22-
from ads.aqua.common.enums import RqsAdditionalDetails
22+
from ads.aqua.common.enums import (
23+
InferenceContainerParamType,
24+
InferenceContainerType,
25+
RqsAdditionalDetails,
26+
)
2327
from ads.aqua.common.errors import (
2428
AquaFileNotFoundError,
2529
AquaRuntimeError,
2630
AquaValueError,
2731
)
28-
from ads.aqua.constants import *
32+
from ads.aqua.constants import (
33+
AQUA_GA_LIST,
34+
COMPARTMENT_MAPPING_KEY,
35+
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
36+
CONTAINER_INDEX,
37+
MAXIMUM_ALLOWED_DATASET_IN_BYTE,
38+
MODEL_BY_REFERENCE_OSS_PATH_KEY,
39+
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
40+
SUPPORTED_FILE_FORMATS,
41+
UNKNOWN,
42+
UNKNOWN_JSON_STR,
43+
)
2944
from ads.aqua.data import AquaResourceIdentifier
3045
from ads.common.auth import default_signer
3146
from ads.common.decorator.threaded import threaded
@@ -74,15 +89,15 @@ def get_status(evaluation_status: str, job_run_status: str = None):
7489

7590
status = LifecycleStatus.UNKNOWN
7691
if evaluation_status == Model.LIFECYCLE_STATE_ACTIVE:
77-
if (
78-
job_run_status == JobRun.LIFECYCLE_STATE_IN_PROGRESS
79-
or job_run_status == JobRun.LIFECYCLE_STATE_ACCEPTED
80-
):
92+
if job_run_status in {
93+
JobRun.LIFECYCLE_STATE_IN_PROGRESS,
94+
JobRun.LIFECYCLE_STATE_ACCEPTED,
95+
}:
8196
status = JobRun.LIFECYCLE_STATE_IN_PROGRESS
82-
elif (
83-
job_run_status == JobRun.LIFECYCLE_STATE_FAILED
84-
or job_run_status == JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION
85-
):
97+
elif job_run_status in {
98+
JobRun.LIFECYCLE_STATE_FAILED,
99+
JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
100+
}:
86101
status = JobRun.LIFECYCLE_STATE_FAILED
87102
else:
88103
status = job_run_status
@@ -199,10 +214,7 @@ def read_file(file_path: str, **kwargs) -> str:
199214
@threaded()
200215
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
201216
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
202-
if artifact_path.startswith("oci://"):
203-
signer = default_signer()
204-
else:
205-
signer = {}
217+
signer = default_signer() if artifact_path.startswith("oci://") else {}
206218
config = json.loads(
207219
read_file(file_path=artifact_path, auth=signer, **kwargs) or UNKNOWN_JSON_STR
208220
)
@@ -448,7 +460,7 @@ def _build_resource_identifier(
448460

449461

450462
def _get_experiment_info(
451-
model: Union[oci.resource_search.models.ResourceSummary, DataScienceModel]
463+
model: Union[oci.resource_search.models.ResourceSummary, DataScienceModel],
452464
) -> tuple:
453465
"""Returns ocid and name of the experiment."""
454466
return (
@@ -609,7 +621,7 @@ def extract_id_and_name_from_tag(tag: str):
609621
base_model_name = UNKNOWN
610622
try:
611623
base_model_ocid, base_model_name = tag.split("#")
612-
except:
624+
except Exception:
613625
pass
614626

615627
if not (is_valid_ocid(base_model_ocid) and base_model_name):
@@ -646,7 +658,7 @@ def get_resource_name(ocid: str) -> str:
646658
try:
647659
resource = query_resource(ocid, return_all=False)
648660
name = resource.display_name if resource else UNKNOWN
649-
except:
661+
except Exception:
650662
name = UNKNOWN
651663
return name
652664

@@ -670,8 +682,8 @@ def get_model_by_reference_paths(model_file_description: dict):
670682

671683
if not models:
672684
raise AquaValueError(
673-
f"Model path is not available in the model json artifact. "
674-
f"Please check if the model created by reference has the correct artifact."
685+
"Model path is not available in the model json artifact. "
686+
"Please check if the model created by reference has the correct artifact."
675687
)
676688

677689
if len(models) > 0:
@@ -848,3 +860,24 @@ def copy_model_config(artifact_path: str, os_path: str, auth: dict = None):
848860
except Exception as ex:
849861
logger.debug(ex)
850862
logger.debug(f"Failed to copy config folder from {artifact_path} to {os_path}.")
863+
864+
865+
def get_container_params_type(container_type_name: str):
866+
"""The utility function accepts the deployment container type name and returns the corresponding params name.
867+
Parameters
868+
----------
869+
container_type_name: str
870+
type of deployment container, like odsc-vllm-serving or odsc-tgi-serving.
871+
872+
Returns
873+
-------
874+
InferenceContainerParamType value
875+
876+
"""
877+
# check substring instead of direct match in case container_type_name changes in the future
878+
if InferenceContainerType.CONTAINER_TYPE_VLLM in container_type_name:
879+
return InferenceContainerParamType.PARAM_TYPE_VLLM
880+
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name:
881+
return InferenceContainerParamType.PARAM_TYPE_TGI
882+
else:
883+
return UNKNOWN

ads/aqua/modeldeployment/deployment.py

Lines changed: 27 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,24 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

6-
import json
75
import logging
86
from typing import Dict, List, Union
97

10-
from oci.data_science.models import ModelDeployment
11-
128
from ads.aqua.app import AquaApp, logger
139
from ads.aqua.common.enums import (
14-
Tags,
15-
InferenceContainerParamType,
1610
InferenceContainerType,
1711
InferenceContainerTypeFamily,
12+
Tags,
1813
)
1914
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
2015
from ads.aqua.common.utils import (
16+
get_combined_params,
2117
get_container_config,
2218
get_container_image,
19+
get_container_params_type,
2320
get_model_by_reference_paths,
2421
get_ocid_substring,
25-
get_combined_params,
2622
get_params_dict,
2723
get_params_list,
2824
get_resource_name,
@@ -38,15 +34,15 @@
3834
from ads.aqua.data import AquaResourceIdentifier
3935
from ads.aqua.finetuning.finetuning import FineTuneCustomMetadata
4036
from ads.aqua.model import AquaModelApp
37+
from ads.aqua.modeldeployment.constants import (
38+
TGIInferenceRestrictedParams,
39+
VLLMInferenceRestrictedParams,
40+
)
4141
from ads.aqua.modeldeployment.entities import (
4242
AquaDeployment,
4343
AquaDeploymentDetail,
4444
ContainerSpec,
4545
)
46-
from ads.aqua.modeldeployment.constants import (
47-
VLLMInferenceRestrictedParams,
48-
TGIInferenceRestrictedParams,
49-
)
5046
from ads.common.object_storage_details import ObjectStorageDetails
5147
from ads.common.utils import get_log_links
5248
from ads.config import (
@@ -187,24 +183,24 @@ def create(
187183
model_name = aqua_model.custom_metadata_list.get(
188184
FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME
189185
).value
190-
except:
186+
except ValueError as err:
191187
raise AquaValueError(
192188
f"Either {FineTuneCustomMetadata.FINE_TUNE_SOURCE} or {FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME} is missing "
193189
f"from custom metadata for the model {config_source_id}"
194-
)
190+
) from err
195191

196192
# set up env vars
197193
if not env_var:
198-
env_var = dict()
194+
env_var = {}
199195

200196
try:
201197
model_path_prefix = aqua_model.custom_metadata_list.get(
202198
MODEL_BY_REFERENCE_OSS_PATH_KEY
203199
).value.rstrip("/")
204-
except ValueError:
200+
except ValueError as err:
205201
raise AquaValueError(
206202
f"{MODEL_BY_REFERENCE_OSS_PATH_KEY} key is not available in the custom metadata field."
207-
)
203+
) from err
208204

209205
if ObjectStorageDetails.is_oci_path(model_path_prefix):
210206
os_path = ObjectStorageDetails.from_path(model_path_prefix)
@@ -219,7 +215,7 @@ def create(
219215

220216
if not fine_tune_output_path:
221217
raise AquaValueError(
222-
f"Fine tuned output path is not available in the model artifact."
218+
"Fine tuned output path is not available in the model artifact."
223219
)
224220

225221
os_path = ObjectStorageDetails.from_path(fine_tune_output_path)
@@ -232,7 +228,7 @@ def create(
232228
container_type_key = aqua_model.custom_metadata_list.get(
233229
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
234230
).value
235-
except ValueError:
231+
except ValueError as err:
236232
message = (
237233
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field "
238234
f"for model {aqua_model.id}."
@@ -242,7 +238,7 @@ def create(
242238
raise AquaValueError(
243239
f"{message}. For unverified Aqua models, container_family parameter should be "
244240
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
245-
)
241+
) from err
246242
container_type_key = container_family
247243
try:
248244
# Check if the container override flag is set. If set, then the user has chosen custom image
@@ -282,11 +278,12 @@ def create(
282278
) # Give precendece to the input parameter
283279

284280
deployment_config = self.get_deployment_config(config_source_id)
285-
vllm_params = (
281+
282+
config_params = (
286283
deployment_config.get("configuration", UNKNOWN_DICT)
287284
.get(instance_shape, UNKNOWN_DICT)
288285
.get("parameters", UNKNOWN_DICT)
289-
.get(InferenceContainerParamType.PARAM_TYPE_VLLM, UNKNOWN)
286+
.get(get_container_params_type(container_type_key), UNKNOWN)
290287
)
291288

292289
# validate user provided params
@@ -301,7 +298,7 @@ def create(
301298
f"and cannot be overridden or are invalid."
302299
)
303300

304-
deployment_params = get_combined_params(vllm_params, user_params)
301+
deployment_params = get_combined_params(config_params, user_params)
305302

306303
if deployment_params:
307304
params = f"{params} {deployment_params}"
@@ -429,7 +426,7 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
429426
# tracks unique deployments that were listed in the user compartment
430427
# we arbitrarily choose last 8 characters of OCID to identify MD in telemetry
431428
self.telemetry.record_event_async(
432-
category=f"aqua/deployment",
429+
category="aqua/deployment",
433430
action="list",
434431
detail=get_ocid_substring(deployment_id, key_len=8),
435432
value=state,
@@ -574,25 +571,12 @@ def get_deployment_default_params(
574571
container_type_key = container_type_key.lower()
575572
if container_type_key in InferenceContainerTypeFamily.values():
576573
deployment_config = self.get_deployment_config(model_id)
577-
config_parameters = (
574+
params = (
578575
deployment_config.get("configuration", UNKNOWN_DICT)
579576
.get(instance_shape, UNKNOWN_DICT)
580577
.get("parameters", UNKNOWN_DICT)
578+
.get(get_container_params_type(container_type_key))
581579
)
582-
if InferenceContainerType.CONTAINER_TYPE_VLLM in container_type_key:
583-
params = config_parameters.get(
584-
InferenceContainerParamType.PARAM_TYPE_VLLM, UNKNOWN
585-
)
586-
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_key:
587-
params = config_parameters.get(
588-
InferenceContainerParamType.PARAM_TYPE_TGI, UNKNOWN
589-
)
590-
else:
591-
params = UNKNOWN
592-
logger.debug(
593-
f"Default inference parameters are not available for the model {model_id} and "
594-
f"instance {instance_shape}."
595-
)
596580
if params:
597581
# account for param that can have --arg but no values, e.g. --trust-remote-code
598582
default_params.extend(get_params_list(params))
@@ -629,7 +613,7 @@ def validate_deployment_params(
629613
container_type_key = model.custom_metadata_list.get(
630614
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
631615
).value
632-
except ValueError:
616+
except ValueError as err:
633617
message = (
634618
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field "
635619
f"for model {model_id}."
@@ -640,7 +624,7 @@ def validate_deployment_params(
640624
raise AquaValueError(
641625
f"{message}. For unverified Aqua models, container_family parameter should be "
642626
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
643-
)
627+
) from err
644628
container_type_key = container_family
645629

646630
container_config = get_container_config()
@@ -658,7 +642,7 @@ def validate_deployment_params(
658642
f"Parameters {restricted_params} are set by Aqua "
659643
f"and cannot be overridden or are invalid."
660644
)
661-
return dict(valid=True)
645+
return {"valid": True}
662646

663647
@staticmethod
664648
def _find_restricted_params(
@@ -689,7 +673,7 @@ def _find_restricted_params(
689673
default_params_dict = get_params_dict(default_params)
690674
user_params_dict = get_params_dict(user_params)
691675

692-
for key, items in user_params_dict.items():
676+
for key, _items in user_params_dict.items():
693677
if (
694678
key in default_params_dict
695679
or (
@@ -701,6 +685,6 @@ def _find_restricted_params(
701685
and key in TGIInferenceRestrictedParams
702686
)
703687
):
704-
restricted_params.append(key.lstrip("--"))
688+
restricted_params.append(key.lstrip("-"))
705689

706690
return restricted_params

tests/unitary/with_extras/aqua/test_data/deployment/deployment_config.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"configuration": {
33
"VM.GPU.A10.1": {
44
"parameters": {
5+
"TGI_PARAMS": "--max-stop-sequences 6",
56
"VLLM_PARAMS": "--max-model-len 4096"
67
}
78
}

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -352,22 +352,10 @@ def test_create_deployment_for_fine_tuned_model(
352352
):
353353
"""Test to create a deployment for fine-tuned model"""
354354

355-
# todo: DataScienceModel.from_yaml should update model_file_description attribute, current workaround is to
356-
# load using with_model_file_description property.
357-
def yaml_to_json(input_file):
358-
with open(input_file, "r") as f:
359-
return yaml.safe_load(f)
360-
361355
aqua_model = os.path.join(
362356
self.curr_dir, "test_data/deployment/aqua_finetuned_model.yaml"
363357
)
364-
model_description_json = json.dumps(
365-
yaml_to_json(aqua_model)["spec"]["modelDescription"]
366-
)
367358
datascience_model = DataScienceModel.from_yaml(uri=aqua_model)
368-
datascience_model.with_model_file_description(
369-
json_string=model_description_json
370-
)
371359
mock_create.return_value = datascience_model
372360

373361
config_json = os.path.join(

0 commit comments

Comments
 (0)