Skip to content

Commit e1fe284

Browse files
add container spec and update deployment
1 parent 8687883 commit e1fe284

File tree

10 files changed

+157
-30
lines changed

10 files changed

+157
-30
lines changed

ads/aqua/common/entities.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4+
5+
6+
class ContainerSpec:
7+
"""
8+
Class to hold keys within the container spec.
9+
"""
10+
11+
CONTAINER_SPEC = "containerSpec"
12+
CLI_PARM = "cliParam"
13+
SERVER_PORT = "serverPort"
14+
HEALTH_CHECK_PORT = "healthCheckPort"
15+
ENV_VARS = "envVars"
16+
RESTRICTED_PARAMS = "restrictedParams"

ads/aqua/common/enums.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class Tags(str, metaclass=ExtendedEnumMeta):
4444
class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
4545
CONTAINER_TYPE_VLLM = "vllm"
4646
CONTAINER_TYPE_TGI = "tgi"
47+
CONTAINER_TYPE_LLAMA_CPP = "llama-cpp"
4748

4849

4950
class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
@@ -55,6 +56,7 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5556
class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
5657
PARAM_TYPE_VLLM = "VLLM_PARAMS"
5758
PARAM_TYPE_TGI = "TGI_PARAMS"
59+
PARAM_TYPE_LLAMA_CPP = "LLAMA_CPP_PARAMS"
5860

5961

6062
class HuggingFaceTags(str, metaclass=ExtendedEnumMeta):

ads/aqua/common/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
from typing import List, Union
1818

1919
import fsspec
20+
import oci
2021
from cachetools import TTLCache, cached
22+
from oci.data_science.models import JobRun, Model
23+
from oci.object_storage.models import ObjectSummary
2124

22-
import oci
2325
from ads.aqua.common.enums import (
2426
InferenceContainerParamType,
2527
InferenceContainerType,
@@ -52,8 +54,6 @@
5254
from ads.common.utils import copy_file, get_console_link, upload_to_os
5355
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
5456
from ads.model import DataScienceModel, ModelVersionSet
55-
from oci.data_science.models import JobRun, Model
56-
from oci.object_storage.models import ObjectSummary
5757

5858
logger = logging.getLogger("ads.aqua")
5959

@@ -909,6 +909,8 @@ def get_container_params_type(container_type_name: str) -> str:
909909
return InferenceContainerParamType.PARAM_TYPE_VLLM
910910
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
911911
return InferenceContainerParamType.PARAM_TYPE_TGI
912+
elif InferenceContainerType.CONTAINER_TYPE_LLAMA_CPP in container_type_name.lower():
913+
return InferenceContainerParamType.PARAM_TYPE_LLAMA_CPP
912914
else:
913915
return UNKNOWN
914916

ads/aqua/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,7 @@
7474
"--sharded",
7575
"--trust-remote-code",
7676
}
77+
LLAMA_CPP_INFERENCE_RESTRICTED_PARAMS = {
78+
"--port",
79+
"--host",
80+
}

ads/aqua/modeldeployment/deployment.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Dict, List, Optional, Union
77

88
from ads.aqua.app import AquaApp, logger
9+
from ads.aqua.common.entities import ContainerSpec
910
from ads.aqua.common.enums import (
1011
InferenceContainerTypeFamily,
1112
Tags,
@@ -38,7 +39,6 @@
3839
from ads.aqua.modeldeployment.entities import (
3940
AquaDeployment,
4041
AquaDeploymentDetail,
41-
ContainerSpec,
4242
)
4343
from ads.aqua.ui import ModelFormat
4444
from ads.common.object_storage_details import ObjectStorageDetails
@@ -281,8 +281,10 @@ def create(
281281
f"Aqua Image used for deploying {aqua_model.id} : {container_image}"
282282
)
283283

284+
# todo: use AquaContainerConfig.from_container_index_json instead.
284285
# Fetch the startup cli command for the container
285-
# container_index.json will have "containerSpec" section which will provide the cli params for a given container family
286+
# container_index.json will have "containerSpec" section which will provide the cli params for
287+
# a given container family
286288
container_config = get_container_config()
287289
container_spec = container_config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
288290
container_type_key, {}
@@ -308,6 +310,18 @@ def create(
308310
# validate user provided params
309311
user_params = env_var.get("PARAMS", UNKNOWN)
310312
if user_params:
313+
# todo: remove this check in a future version; the logic is to be moved to container_index
314+
if (
315+
container_type_key.lower()
316+
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
317+
):
318+
# AQUA_LLAMA_CPP_CONTAINER_FAMILY container uses uvicorn, which requires model/server params
319+
# to be set as env vars
320+
raise AquaValueError(
321+
f"Currently, parameters cannot be overridden for the container: {container_image}. Please proceed "
322+
f"with deployment without parameter overrides."
323+
)
324+
311325
restricted_params = self._find_restricted_params(
312326
params, user_params, container_type_key
313327
)

ads/aqua/modeldeployment/entities.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -24,18 +23,6 @@ class ModelParams:
2423
model: str = None
2524

2625

27-
class ContainerSpec:
28-
"""
29-
Class to hold to hold keys within the container spec.
30-
"""
31-
32-
CONTAINER_SPEC = "containerSpec"
33-
CLI_PARM = "cliParam"
34-
SERVER_PORT = "serverPort"
35-
HEALTH_CHECK_PORT = "healthCheckPort"
36-
ENV_VARS = "envVars"
37-
38-
3926
@dataclass
4027
class ShapeInfo:
4128
instance_shape: str = None
@@ -83,15 +70,11 @@ def from_oci_model_deployment(
8370
AquaDeployment:
8471
The instance of the Aqua model deployment.
8572
"""
86-
instance_configuration = (
87-
oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration
88-
)
73+
instance_configuration = oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration
8974
instance_shape_config_details = (
9075
instance_configuration.model_deployment_instance_shape_config_details
9176
)
92-
instance_count = (
93-
oci_model_deployment.model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count
94-
)
77+
instance_count = oci_model_deployment.model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count
9578
shape_info = ShapeInfo(
9679
instance_shape=instance_configuration.instance_shape_name,
9780
instance_count=instance_count,

ads/aqua/ui.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
from typing import Dict, List, Optional
1010

1111
from cachetools import TTLCache
12+
from oci.exceptions import ServiceError
13+
from oci.identity.models import Compartment
1214

1315
from ads.aqua import logger
1416
from ads.aqua.app import AquaApp
17+
from ads.aqua.common.entities import ContainerSpec
1518
from ads.aqua.common.enums import Tags
1619
from ads.aqua.common.errors import AquaResourceAccessError, AquaValueError
1720
from ads.aqua.common.utils import get_container_config, load_config, sanitize_response
@@ -27,8 +30,6 @@
2730
TENANCY_OCID,
2831
)
2932
from ads.telemetry import telemetry
30-
from oci.exceptions import ServiceError
31-
from oci.identity.models import Compartment
3233

3334

3435
class ModelFormat(Enum):
@@ -40,6 +41,19 @@ def to_dict(self):
4041
return self.value
4142

4243

44+
# todo: the container config spec information is shared across ui and deployment modules, move them
45+
# within ads.aqua.common.entities. In that case, check for circular imports due to usage of get_container_config.
46+
47+
48+
@dataclass(repr=False)
49+
class AquaContainerConfigSpec(DataClassSerializable):
50+
cli_param: str = None
51+
server_port: str = None
52+
health_check_port: str = None
53+
env_vars: List[dict] = None
54+
restricted_params: List[str] = None
55+
56+
4357
@dataclass(repr=False)
4458
class AquaContainerConfigItem(DataClassSerializable):
4559
"""Represents an item of the AQUA container configuration."""
@@ -60,6 +74,7 @@ def __repr__(self):
6074
family: str = None
6175
platforms: List[Platform] = None
6276
model_formats: List[ModelFormat] = None
77+
spec: AquaContainerConfigSpec = field(default_factory=AquaContainerConfigSpec)
6378

6479

6580
@dataclass(repr=False)
@@ -81,7 +96,9 @@ def to_dict(self):
8196

8297
@classmethod
8398
def from_container_index_json(
84-
cls, config: Optional[Dict] = None
99+
cls,
100+
config: Optional[Dict] = None,
101+
enable_spec: Optional[bool] = False,
85102
) -> "AquaContainerConfig":
86103
"""
87104
Create an AquaContainerConfig instance from a container index JSON.
@@ -90,6 +107,8 @@ def from_container_index_json(
90107
----------
91108
config : Dict
92109
The container index JSON.
110+
enable_spec: bool
111+
Flag indicating whether container specification details should be fetched.
93112
94113
Returns
95114
-------
@@ -114,6 +133,13 @@ def from_container_index_json(
114133
ModelFormat[model_format]
115134
for model_format in container.get("modelFormats", [])
116135
]
136+
container_spec = (
137+
config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
138+
container_type, {}
139+
)
140+
if enable_spec
141+
else None
142+
)
117143
container_item = AquaContainerConfigItem(
118144
name=container.get("name", ""),
119145
version=container.get("version", ""),
@@ -123,6 +149,21 @@ def from_container_index_json(
123149
family=container_type,
124150
platforms=platforms,
125151
model_formats=model_formats,
152+
spec=AquaContainerConfigSpec(
153+
cli_param=container_spec.get(ContainerSpec.CLI_PARM, ""),
154+
server_port=container_spec.get(
155+
ContainerSpec.SERVER_PORT, ""
156+
),
157+
health_check_port=container_spec.get(
158+
ContainerSpec.HEALTH_CHECK_PORT, ""
159+
),
160+
env_vars=container_spec.get(ContainerSpec.ENV_VARS, []),
161+
restricted_params=container_spec.get(
162+
ContainerSpec.RESTRICTED_PARAMS, []
163+
),
164+
)
165+
if container_spec
166+
else None,
126167
)
127168
if container.get("type") == "inference":
128169
inference_items[container_type] = container_item
@@ -571,5 +612,6 @@ def list_containers(self) -> AquaContainerConfig:
571612
The AQUA containers configurations.
572613
"""
573614
return AquaContainerConfig.from_container_index_json(
574-
config=get_container_config()
615+
config=get_container_config(),
616+
enable_spec=True,
575617
)

tests/unitary/with_extras/aqua/test_data/ui/container_index.json

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
}
2121
],
2222
"healthCheckPort": "8080",
23+
"restrictedParams": [],
2324
"serverPort": "8080"
2425
},
2526
"odsc-tgi-serving": {
@@ -39,10 +40,17 @@
3940
}
4041
],
4142
"healthCheckPort": "8080",
43+
"restrictedParams": [
44+
"--port",
45+
"--hostname",
46+
"--num-shard",
47+
"--sharded",
48+
"--trust-remote-code"
49+
],
4250
"serverPort": "8080"
4351
},
4452
"odsc-vllm-serving": {
45-
"cliParam": "--served-model-name $(python -c 'import os; print(os.environ.get(\"ODSC_SERVED_MODEL_NAME\",\"odsc-llm\"))') --seed 42 ",
53+
"cliParam": "--served-model-name odsc-llm --seed 42 ",
4654
"envVars": [
4755
{
4856
"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions"
@@ -58,6 +66,12 @@
5866
}
5967
],
6068
"healthCheckPort": "8080",
69+
"restrictedParams": [
70+
"--port",
71+
"--host",
72+
"--served-model-name",
73+
"--seed"
74+
],
6175
"serverPort": "8080"
6276
}
6377
},

tests/unitary/with_extras/aqua/test_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def setUp(self, ipython_init_mock) -> None:
126126
"status": 500,
127127
"message": "Internal Server Error",
128128
"service_payload": {},
129-
"reason": f"MultipartUploadError: MultipartUploadError exception has occurred. {UPLOAD_MANAGER_DEBUG_INFORMATION_LOG}",
129+
"reason": f"MultipartUploadError: MultipartUploadError exception has occured. {UPLOAD_MANAGER_DEBUG_INFORMATION_LOG}",
130130
"request_id": TestDataset.mock_request_id,
131131
},
132132
],

0 commit comments

Comments
 (0)