Skip to content

Commit 0010fd7

Browse files
accept container family as input
1 parent 2502185 commit 0010fd7

File tree

6 files changed

+152
-26
lines changed

6 files changed

+152
-26
lines changed

ads/aqua/extension/deployment_handler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def post(self, *args, **kwargs):
9999
server_port = input_data.get("server_port")
100100
health_check_port = input_data.get("health_check_port")
101101
env_var = input_data.get("env_var")
102+
container_family = input_data.get("container_family")
102103

103104
self.finish(
104105
AquaDeploymentApp().create(
@@ -117,6 +118,7 @@ def post(self, *args, **kwargs):
117118
server_port=server_port,
118119
health_check_port=health_check_port,
119120
env_var=env_var,
121+
container_family=container_family,
120122
)
121123
)
122124

@@ -245,10 +247,12 @@ def post(self, *args, **kwargs):
245247
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id"))
246248

247249
params = input_data.get("params")
250+
container_family = input_data.get("container_family")
248251
return self.finish(
249252
AquaDeploymentApp().validate_deployment_params(
250253
model_id=model_id,
251254
params=params,
255+
container_family=container_family,
252256
)
253257
)
254258

ads/aqua/model/model.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from ads.aqua.model.constants import *
4040
from ads.aqua.model.entities import *
41+
from ads.aqua.modeldeployment.enums import InferenceContainerTypeKey
4142
from ads.common.auth import default_signer
4243
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
4344
from ads.common.utils import get_console_link
@@ -621,7 +622,15 @@ def _create_model_catalog_entry(
621622
)
622623
else:
623624
logger.warn(
624-
f"Require Inference container information. Model: {model_name} does not have associated inference container defaults. Check docs for more information on how to pass inference container. Proceeding with model registration without the fine-tuning container information. This model will not be available for fine tuning."
625+
f"Proceeding with model registration without the fine-tuning container information. "
626+
f"This model will not be available for fine tuning."
627+
)
628+
629+
if not inference_container:
630+
inference_container = InferenceContainerTypeKey.AQUA_TGI_CONTAINER_KEY
631+
logger.info(
632+
f"Model: {model_name} does not have associated inference container defaults. "
633+
f"{inference_container} will be used instead."
625634
)
626635
metadata.add(
627636
key=AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,

ads/aqua/modeldeployment/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@
1010
This module contains constants used in Aqua Model Deployment.
1111
"""
1212

13-
VLLMInferenceRestrictedParams = {"tensor-parallel-size"}
13+
VLLMInferenceRestrictedParams = {"--tensor-parallel-size"}
14+
TGIInferenceRestrictedParams = {"--port"}

ads/aqua/modeldeployment/deployment.py

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@
3838
AquaDeploymentDetail,
3939
ContainerSpec,
4040
)
41-
from ads.aqua.modeldeployment.constants import VLLMInferenceRestrictedParams
41+
from ads.aqua.modeldeployment.constants import (
42+
VLLMInferenceRestrictedParams,
43+
TGIInferenceRestrictedParams,
44+
)
4245
from ads.aqua.modeldeployment.enums import (
4346
InferenceContainerParamType,
4447
InferenceContainerType,
@@ -106,6 +109,7 @@ def create(
106109
server_port: int = None,
107110
health_check_port: int = None,
108111
env_var: Dict = None,
112+
container_family: str = None,
109113
) -> "AquaDeployment":
110114
"""
111115
Creates a new Aqua deployment
@@ -144,6 +148,8 @@ def create(
144148
The health check port for docker container image.
145149
env_var : dict, optional
146150
Environment variable for the deployment, by default None.
151+
container_family: str
152+
The image family of model deployment container runtime. Required for unverified Aqua models.
147153
Returns
148154
-------
149155
AquaDeployment
@@ -227,9 +233,17 @@ def create(
227233
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
228234
).value
229235
except ValueError:
230-
raise AquaValueError(
231-
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field for model {aqua_model.id}"
236+
message = (
237+
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field "
238+
f"for model {aqua_model.id}."
232239
)
240+
logger.debug(message)
241+
if not container_family:
242+
raise AquaValueError(
243+
f"{message}. For unverified Aqua models, container_family parameter should be "
244+
f"set and value can be one of {', '.join(InferenceContainerTypeKey.values())}."
245+
)
246+
container_type_key = container_family
233247
try:
234248
# Check if the container override flag is set. If set, then the user has chosen custom image
235249
if aqua_model.custom_metadata_list.get(
@@ -275,13 +289,12 @@ def create(
275289
.get(InferenceContainerParamType.PARAM_TYPE_VLLM, UNKNOWN)
276290
)
277291

278-
# todo: add support for tgi once parameters are added to configs. _find_restricted_params can take in
279-
# additional parameter container_type_key and should validate against TGIInferenceRestrictedParams set for
280-
# restricted params.
281292
# validate user provided params
282293
user_params = env_var.get("PARAMS", UNKNOWN)
283294
if user_params:
284-
restricted_params = self._find_restricted_params(params, user_params)
295+
restricted_params = self._find_restricted_params(
296+
params, user_params, container_type_key
297+
)
285298
if restricted_params:
286299
raise AquaValueError(
287300
f"Parameters {restricted_params} are set by Aqua "
@@ -587,7 +600,10 @@ def get_deployment_default_params(
587600
return default_params
588601

589602
def validate_deployment_params(
590-
self, model_id: str, params: List[str] = None
603+
self,
604+
model_id: str,
605+
params: List[str] = None,
606+
container_family: str = None,
591607
) -> Dict:
592608
"""Validate if the deployment parameters passed by the user can be overridden. Parameter values are not
593609
validated, only param keys are validated.
@@ -596,9 +612,10 @@ def validate_deployment_params(
596612
----------
597613
model_id: str
598614
The OCID of the Aqua model.
599-
600615
params : List[str], optional
601616
Params passed by the user.
617+
container_family: str
618+
The image family of model deployment container runtime. Required for unverified Aqua models.
602619
603620
Returns
604621
-------
@@ -613,18 +630,28 @@ def validate_deployment_params(
613630
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME
614631
).value
615632
except ValueError:
616-
container_type_key = UNKNOWN
617-
logger.debug(
618-
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field for model {model_id}."
633+
message = (
634+
f"{AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} key is not available in the custom metadata field "
635+
f"for model {model_id}."
619636
)
620-
if container_type_key:
621-
container_config = get_container_config()
622-
container_spec = container_config.get(
623-
ContainerSpec.CONTAINER_SPEC, {}
624-
).get(container_type_key, {})
625-
cli_params = container_spec.get(ContainerSpec.CLI_PARM, "")
637+
logger.debug(message)
626638

627-
restricted_params = self._find_restricted_params(cli_params, params)
639+
if not container_family:
640+
raise AquaValueError(
641+
f"{message}. For unverified Aqua models, container_family parameter should be "
642+
f"set and value can be one of {', '.join(InferenceContainerTypeKey.values())}."
643+
)
644+
container_type_key = container_family
645+
646+
container_config = get_container_config()
647+
container_spec = container_config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
648+
container_type_key, {}
649+
)
650+
cli_params = container_spec.get(ContainerSpec.CLI_PARM, "")
651+
652+
restricted_params = self._find_restricted_params(
653+
cli_params, params, container_type_key
654+
)
628655

629656
if restricted_params:
630657
raise AquaValueError(
@@ -635,7 +662,9 @@ def validate_deployment_params(
635662

636663
@staticmethod
637664
def _find_restricted_params(
638-
default_params: Union[str, List[str]], user_params: Union[str, List[str]]
665+
default_params: Union[str, List[str]],
666+
user_params: Union[str, List[str]],
667+
container_family: str,
639668
) -> List[str]:
640669
"""Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
641670
The default parameters coming from the container index json file cannot be overridden. In addition to this,
@@ -647,6 +676,8 @@ def _find_restricted_params(
647676
Inference container parameter string with default values.
648677
user_params:
649678
Inference container parameter string with user provided values.
679+
container_family: str
680+
The image family of model deployment container runtime.
650681
651682
Returns
652683
-------
@@ -659,7 +690,17 @@ def _find_restricted_params(
659690
user_params_dict = get_params_dict(user_params)
660691

661692
for key, items in user_params_dict.items():
662-
if key in default_params_dict or key in VLLMInferenceRestrictedParams:
693+
if (
694+
key in default_params_dict
695+
or (
696+
InferenceContainerType.CONTAINER_TYPE_VLLM in container_family
697+
and key in VLLMInferenceRestrictedParams
698+
)
699+
or (
700+
InferenceContainerType.CONTAINER_TYPE_TGI in container_family
701+
and key in TGIInferenceRestrictedParams
702+
)
703+
):
663704
restricted_params.append(key.lstrip("--"))
664705

665706
return restricted_params

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,14 @@ def test_get_deployment_default_params(
494494
"custom-container-key",
495495
["--max-model-len 4096", "--seed 42", "--trust-remote-code"],
496496
),
497+
(
498+
"odsc-vllm-serving",
499+
["--tensor-parallel-size 2"],
500+
),
501+
(
502+
"odsc-tgi-serving",
503+
["--port 8080"],
504+
),
497505
]
498506
)
499507
@patch("ads.model.datascience_model.DataScienceModel.from_id")
@@ -530,6 +538,55 @@ def test_validate_deployment_params(
530538
)
531539
assert result["valid"] is True
532540

541+
@parameterized.expand(
542+
[
543+
(
544+
"odsc-vllm-serving",
545+
["--max-model-len 4096"],
546+
),
547+
(
548+
"odsc-tgi-serving",
549+
["--max_stop_sequences 5"],
550+
),
551+
(
552+
"",
553+
["--some_random_key some_random_value"],
554+
),
555+
]
556+
)
557+
@patch("ads.model.datascience_model.DataScienceModel.from_id")
558+
@patch("ads.aqua.modeldeployment.deployment.get_container_config")
559+
def test_validate_deployment_params_for_unverified_models(
560+
self, container_type_key, params, mock_get_container_config, mock_from_id
561+
):
562+
"""Test to check if container family is used when metadata does not have image information
563+
for unverified models."""
564+
mock_model = MagicMock()
565+
mock_model.custom_metadata_list = ModelCustomMetadata()
566+
mock_from_id.return_value = mock_model
567+
568+
container_index_json = os.path.join(
569+
self.curr_dir, "test_data/ui/container_index.json"
570+
)
571+
with open(container_index_json, "r") as _file:
572+
container_index_config = json.load(_file)
573+
mock_get_container_config.return_value = container_index_config
574+
575+
if container_type_key in {"odsc-vllm-serving", "odsc-tgi-serving"} and params:
576+
result = self.app.validate_deployment_params(
577+
model_id="mock-model-id",
578+
params=params,
579+
container_family=container_type_key,
580+
)
581+
assert result["valid"] is True
582+
else:
583+
with pytest.raises(AquaValueError):
584+
self.app.validate_deployment_params(
585+
model_id="mock-model-id",
586+
params=params,
587+
container_family=container_type_key,
588+
)
589+
533590

534591
class TestMDInferenceResponse(unittest.TestCase):
535592
def setUp(self):

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import unittest
99
from importlib import reload
1010
from unittest.mock import MagicMock, patch
11+
from parameterized import parameterized
1112

1213
from notebook.base.handlers import IPythonHandler
1314

@@ -124,6 +125,7 @@ def test_post(self, mock_create):
124125
server_port=None,
125126
health_check_port=None,
126127
env_var=None,
128+
container_family=None,
127129
)
128130

129131

@@ -156,21 +158,33 @@ def test_get_deployment_default_params(
156158
model_id="test_model_id", instance_shape=TestDataset.INSTANCE_SHAPE
157159
)
158160

161+
@parameterized.expand(
162+
[
163+
None,
164+
"container-family-name",
165+
]
166+
)
159167
@patch("notebook.base.handlers.APIHandler.finish")
160168
@patch("ads.aqua.modeldeployment.AquaDeploymentApp.validate_deployment_params")
161169
def test_validate_deployment_params(
162-
self, mock_validate_deployment_params, mock_finish
170+
self, container_family_value, mock_validate_deployment_params, mock_finish
163171
):
164172
mock_validate_deployment_params.return_value = dict(valid=True)
165173
mock_finish.side_effect = lambda x: x
166174

167175
self.test_instance.get_json_body = MagicMock(
168-
return_value=dict(model_id="test-model-id", params=self.default_params)
176+
return_value=dict(
177+
model_id="test-model-id",
178+
params=self.default_params,
179+
container_family=container_family_value,
180+
)
169181
)
170182
result = self.test_instance.post()
171183
assert result["valid"] is True
172184
mock_validate_deployment_params.assert_called_with(
173-
model_id="test-model-id", params=self.default_params
185+
model_id="test-model-id",
186+
params=self.default_params,
187+
container_family=container_family_value,
174188
)
175189

176190

0 commit comments

Comments
 (0)