Skip to content

Commit 9d1a67b

Browse files
Accept container family as input for model creation and param validation (#862)
2 parents 227a177 + 114730a commit 9d1a67b

File tree

8 files changed

+184
-56
lines changed

8 files changed

+184
-56
lines changed

ads/aqua/common/enums.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,25 @@ class Tags(str, metaclass=ExtendedEnumMeta):
4040
AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
4141

4242

43+
class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
44+
CONTAINER_TYPE_VLLM = "vllm"
45+
CONTAINER_TYPE_TGI = "tgi"
46+
47+
48+
class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
49+
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
50+
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
51+
52+
53+
class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
54+
PARAM_TYPE_VLLM = "VLLM_PARAMS"
55+
PARAM_TYPE_TGI = "TGI_PARAMS"
56+
57+
58+
class HuggingFaceTags(str, metaclass=ExtendedEnumMeta):
59+
TEXT_GENERATION_INFERENCE = "text-generation-inference"
60+
61+
4362
class RqsAdditionalDetails(str, metaclass=ExtendedEnumMeta):
4463
METADATA = "metadata"
4564
CREATED_BY = "createdBy"

ads/aqua/extension/deployment_handler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def post(self, *args, **kwargs):
9999
server_port = input_data.get("server_port")
100100
health_check_port = input_data.get("health_check_port")
101101
env_var = input_data.get("env_var")
102+
container_family = input_data.get("container_family")
102103

103104
self.finish(
104105
AquaDeploymentApp().create(
@@ -117,6 +118,7 @@ def post(self, *args, **kwargs):
117118
server_port=server_port,
118119
health_check_port=health_check_port,
119120
env_var=env_var,
121+
container_family=container_family,
120122
)
121123
)
122124

@@ -245,10 +247,12 @@ def post(self, *args, **kwargs):
245247
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id"))
246248

247249
params = input_data.get("params")
250+
container_family = input_data.get("container_family")
248251
return self.finish(
249252
AquaDeploymentApp().validate_deployment_params(
250253
model_id=model_id,
251254
params=params,
255+
container_family=container_family,
252256
)
253257
)
254258

ads/aqua/model/model.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
1515
from ads.aqua.app import AquaApp
16-
from ads.aqua.common.enums import Tags
16+
from ads.aqua.common.enums import Tags, HuggingFaceTags, InferenceContainerTypeFamily
1717
from ads.aqua.common.errors import AquaRuntimeError
1818
from ads.aqua.common.utils import (
1919
create_word_icon,
@@ -621,7 +621,21 @@ def _create_model_catalog_entry(
621621
)
622622
else:
623623
logger.warn(
624-
f"Require Inference container information. Model: {model_name} does not have associated inference container defaults. Check docs for more information on how to pass inference container. Proceeding with model registration without the fine-tuning container information. This model will not be available for fine tuning."
624+
f"Proceeding with model registration without the fine-tuning container information. "
625+
f"This model will not be available for fine tuning."
626+
)
627+
628+
if not inference_container:
629+
inference_container = (
630+
InferenceContainerTypeFamily.AQUA_TGI_CONTAINER_FAMILY
631+
if model_info
632+
and model_info.tags
633+
and HuggingFaceTags.TEXT_GENERATION_INFERENCE in model_info.tags
634+
else InferenceContainerTypeFamily.AQUA_VLLM_CONTAINER_FAMILY
635+
)
636+
logger.info(
637+
f"Model: {model_name} does not have associated inference container defaults. "
638+
f"{inference_container} will be used instead."
625639
)
626640
metadata.add(
627641
key=AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
@@ -774,7 +788,7 @@ def register(
774788
break
775789
if i == retry:
776790
raise Exception(
777-
"Could not download the model {model_name} from https://huggingface.co with message {huggingface_download}"
791+
f"Could not download the model {model_name} from https://huggingface.co with message {huggingface_download_err_message}"
778792
)
779793
os.makedirs(local_dir, exist_ok=True)
780794
# Copy the model from the cache to destination

ads/aqua/modeldeployment/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@
1010
This module contains constants used in Aqua Model Deployment.
1111
"""
1212

13-
VLLMInferenceRestrictedParams = {"tensor-parallel-size"}
13+
VLLMInferenceRestrictedParams = {"--tensor-parallel-size"}
14+
TGIInferenceRestrictedParams = {"--port"}

ads/aqua/modeldeployment/deployment.py

Lines changed: 68 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010
from oci.data_science.models import ModelDeployment
1111

1212
from ads.aqua.app import AquaApp, logger
13-
from ads.aqua.common.enums import Tags
13+
from ads.aqua.common.enums import (
14+
Tags,
15+
InferenceContainerParamType,
16+
InferenceContainerType,
17+
InferenceContainerTypeFamily,
18+
)
1419
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1520
from ads.aqua.common.utils import (
1621
get_container_config,
@@ -38,11 +43,9 @@
3843
AquaDeploymentDetail,
3944
ContainerSpec,
4045
)
41-
from ads.aqua.modeldeployment.constants import VLLMInferenceRestrictedParams
42-
from ads.aqua.modeldeployment.enums import (
43-
InferenceContainerParamType,
44-
InferenceContainerType,
45-
InferenceContainerTypeKey,
46+
from ads.aqua.modeldeployment.constants import (
47+
VLLMInferenceRestrictedParams,
48+
TGIInferenceRestrictedParams,
4649
)
4750
from ads.common.object_storage_details import ObjectStorageDetails
4851
from ads.common.utils import get_log_links
@@ -106,6 +109,7 @@ def create(
106109
server_port: int = None,
107110
health_check_port: int = None,
108111
env_var: Dict = None,
112+
container_family: str = None,
109113
) -> "AquaDeployment":
110114
"""
111115
Creates a new Aqua deployment
@@ -144,6 +148,8 @@ def create(
144148
The health check port for docker container image.
145149
env_var : dict, optional
146150
Environment variable for the deployment, by default None.
151+
container_family: str
152+
The image family of model deployment container runtime. Required for unverified Aqua models.
147153
Returns
148154
-------
149155
AquaDeployment
@@ -227,9 +233,17 @@ def create(
227233
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
228234
).value
229235
except ValueError:
230-
raise AquaValueError(
231-
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field for model {aqua_model.id}"
236+
message = (
237+
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field "
238+
f"for model {aqua_model.id}."
232239
)
240+
logger.debug(message)
241+
if not container_family:
242+
raise AquaValueError(
243+
f"{message}. For unverified Aqua models, container_family parameter should be "
244+
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
245+
)
246+
container_type_key = container_family
233247
try:
234248
# Check if the container override flag is set. If set, then the user has chosen custom image
235249
if aqua_model.custom_metadata_list.get(
@@ -275,13 +289,12 @@ def create(
275289
.get(InferenceContainerParamType.PARAM_TYPE_VLLM, UNKNOWN)
276290
)
277291

278-
# todo: add support for tgi once parameters are added to configs. _find_restricted_params can take in
279-
# additional parameter container_type_key and should validate against TGIInferenceRestrictedParams set for
280-
# restricted params.
281292
# validate user provided params
282293
user_params = env_var.get("PARAMS", UNKNOWN)
283294
if user_params:
284-
restricted_params = self._find_restricted_params(params, user_params)
295+
restricted_params = self._find_restricted_params(
296+
params, user_params, container_type_key
297+
)
285298
if restricted_params:
286299
raise AquaValueError(
287300
f"Parameters {restricted_params} are set by Aqua "
@@ -559,7 +572,7 @@ def get_deployment_default_params(
559572

560573
if container_type_key:
561574
container_type_key = container_type_key.lower()
562-
if container_type_key in InferenceContainerTypeKey.values():
575+
if container_type_key in InferenceContainerTypeFamily.values():
563576
deployment_config = self.get_deployment_config(model_id)
564577
config_parameters = (
565578
deployment_config.get("configuration", UNKNOWN_DICT)
@@ -587,7 +600,10 @@ def get_deployment_default_params(
587600
return default_params
588601

589602
def validate_deployment_params(
590-
self, model_id: str, params: List[str] = None
603+
self,
604+
model_id: str,
605+
params: List[str] = None,
606+
container_family: str = None,
591607
) -> Dict:
592608
"""Validate if the deployment parameters passed by the user can be overridden. Parameter values are not
593609
validated, only param keys are validated.
@@ -596,9 +612,10 @@ def validate_deployment_params(
596612
----------
597613
model_id: str
598614
The OCID of the Aqua model.
599-
600615
params : List[str], optional
601616
Params passed by the user.
617+
container_family: str
618+
The image family of model deployment container runtime. Required for unverified Aqua models.
602619
603620
Returns
604621
-------
@@ -613,18 +630,28 @@ def validate_deployment_params(
613630
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
614631
).value
615632
except ValueError:
616-
container_type_key = UNKNOWN
617-
logger.debug(
618-
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field for model {model_id}."
633+
message = (
634+
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field "
635+
f"for model {model_id}."
619636
)
620-
if container_type_key:
621-
container_config = get_container_config()
622-
container_spec = container_config.get(
623-
ContainerSpec.CONTAINER_SPEC, {}
624-
).get(container_type_key, {})
625-
cli_params = container_spec.get(ContainerSpec.CLI_PARM, "")
637+
logger.debug(message)
626638

627-
restricted_params = self._find_restricted_params(cli_params, params)
639+
if not container_family:
640+
raise AquaValueError(
641+
f"{message}. For unverified Aqua models, container_family parameter should be "
642+
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
643+
)
644+
container_type_key = container_family
645+
646+
container_config = get_container_config()
647+
container_spec = container_config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
648+
container_type_key, {}
649+
)
650+
cli_params = container_spec.get(ContainerSpec.CLI_PARM, "")
651+
652+
restricted_params = self._find_restricted_params(
653+
cli_params, params, container_type_key
654+
)
628655

629656
if restricted_params:
630657
raise AquaValueError(
@@ -635,7 +662,9 @@ def validate_deployment_params(
635662

636663
@staticmethod
637664
def _find_restricted_params(
638-
default_params: Union[str, List[str]], user_params: Union[str, List[str]]
665+
default_params: Union[str, List[str]],
666+
user_params: Union[str, List[str]],
667+
container_family: str,
639668
) -> List[str]:
640669
"""Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
641670
The default parameters coming from the container index json file cannot be overridden. In addition to this,
@@ -647,6 +676,8 @@ def _find_restricted_params(
647676
Inference container parameter string with default values.
648677
user_params:
649678
Inference container parameter string with user provided values.
679+
container_family: str
680+
The image family of model deployment container runtime.
650681
651682
Returns
652683
-------
@@ -659,7 +690,17 @@ def _find_restricted_params(
659690
user_params_dict = get_params_dict(user_params)
660691

661692
for key, items in user_params_dict.items():
662-
if key in default_params_dict or key in VLLMInferenceRestrictedParams:
693+
if (
694+
key in default_params_dict
695+
or (
696+
InferenceContainerType.CONTAINER_TYPE_VLLM in container_family
697+
and key in VLLMInferenceRestrictedParams
698+
)
699+
or (
700+
InferenceContainerType.CONTAINER_TYPE_TGI in container_family
701+
and key in TGIInferenceRestrictedParams
702+
)
703+
):
663704
restricted_params.append(key.lstrip("--"))
664705

665706
return restricted_params

ads/aqua/modeldeployment/enums.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,14 @@ def test_get_deployment_default_params(
494494
"custom-container-key",
495495
["--max-model-len 4096", "--seed 42", "--trust-remote-code"],
496496
),
497+
(
498+
"odsc-vllm-serving",
499+
["--tensor-parallel-size 2"],
500+
),
501+
(
502+
"odsc-tgi-serving",
503+
["--port 8080"],
504+
),
497505
]
498506
)
499507
@patch("ads.model.datascience_model.DataScienceModel.from_id")
@@ -530,6 +538,55 @@ def test_validate_deployment_params(
530538
)
531539
assert result["valid"] is True
532540

541+
@parameterized.expand(
542+
[
543+
(
544+
"odsc-vllm-serving",
545+
["--max-model-len 4096"],
546+
),
547+
(
548+
"odsc-tgi-serving",
549+
["--max_stop_sequences 5"],
550+
),
551+
(
552+
"",
553+
["--some_random_key some_random_value"],
554+
),
555+
]
556+
)
557+
@patch("ads.model.datascience_model.DataScienceModel.from_id")
558+
@patch("ads.aqua.modeldeployment.deployment.get_container_config")
559+
def test_validate_deployment_params_for_unverified_models(
560+
self, container_type_key, params, mock_get_container_config, mock_from_id
561+
):
562+
"""Test to check if container family is used when metadata does not have image information
563+
for unverified models."""
564+
mock_model = MagicMock()
565+
mock_model.custom_metadata_list = ModelCustomMetadata()
566+
mock_from_id.return_value = mock_model
567+
568+
container_index_json = os.path.join(
569+
self.curr_dir, "test_data/ui/container_index.json"
570+
)
571+
with open(container_index_json, "r") as _file:
572+
container_index_config = json.load(_file)
573+
mock_get_container_config.return_value = container_index_config
574+
575+
if container_type_key in {"odsc-vllm-serving", "odsc-tgi-serving"} and params:
576+
result = self.app.validate_deployment_params(
577+
model_id="mock-model-id",
578+
params=params,
579+
container_family=container_type_key,
580+
)
581+
assert result["valid"] is True
582+
else:
583+
with pytest.raises(AquaValueError):
584+
self.app.validate_deployment_params(
585+
model_id="mock-model-id",
586+
params=params,
587+
container_family=container_type_key,
588+
)
589+
533590

534591
class TestMDInferenceResponse(unittest.TestCase):
535592
def setUp(self):

0 commit comments

Comments
 (0)