Skip to content

Commit ca3c278

Browse files
[ODSC-60107] Return container spec with list containers API (#906)
2 parents 8687883 + 53bff8d commit ca3c278

File tree

10 files changed

+180
-30
lines changed

10 files changed

+180
-30
lines changed

ads/aqua/common/entities.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4+
5+
6+
class ContainerSpec:
7+
"""
8+
Class to hold to hold keys within the container spec.
9+
"""
10+
11+
CONTAINER_SPEC = "containerSpec"
12+
CLI_PARM = "cliParam"
13+
SERVER_PORT = "serverPort"
14+
HEALTH_CHECK_PORT = "healthCheckPort"
15+
ENV_VARS = "envVars"
16+
RESTRICTED_PARAMS = "restrictedParams"

ads/aqua/common/enums.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class Tags(str, metaclass=ExtendedEnumMeta):
4444
class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
4545
CONTAINER_TYPE_VLLM = "vllm"
4646
CONTAINER_TYPE_TGI = "tgi"
47+
CONTAINER_TYPE_LLAMA_CPP = "llama-cpp"
4748

4849

4950
class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
@@ -55,6 +56,7 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5556
class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
5657
PARAM_TYPE_VLLM = "VLLM_PARAMS"
5758
PARAM_TYPE_TGI = "TGI_PARAMS"
59+
PARAM_TYPE_LLAMA_CPP = "LLAMA_CPP_PARAMS"
5860

5961

6062
class HuggingFaceTags(str, metaclass=ExtendedEnumMeta):

ads/aqua/common/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
from typing import List, Union
1818

1919
import fsspec
20+
import oci
2021
from cachetools import TTLCache, cached
22+
from oci.data_science.models import JobRun, Model
23+
from oci.object_storage.models import ObjectSummary
2124

22-
import oci
2325
from ads.aqua.common.enums import (
2426
InferenceContainerParamType,
2527
InferenceContainerType,
@@ -52,8 +54,6 @@
5254
from ads.common.utils import copy_file, get_console_link, upload_to_os
5355
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
5456
from ads.model import DataScienceModel, ModelVersionSet
55-
from oci.data_science.models import JobRun, Model
56-
from oci.object_storage.models import ObjectSummary
5757

5858
logger = logging.getLogger("ads.aqua")
5959

@@ -909,6 +909,8 @@ def get_container_params_type(container_type_name: str) -> str:
909909
return InferenceContainerParamType.PARAM_TYPE_VLLM
910910
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
911911
return InferenceContainerParamType.PARAM_TYPE_TGI
912+
elif InferenceContainerType.CONTAINER_TYPE_LLAMA_CPP in container_type_name.lower():
913+
return InferenceContainerParamType.PARAM_TYPE_LLAMA_CPP
912914
else:
913915
return UNKNOWN
914916

ads/aqua/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,7 @@
7474
"--sharded",
7575
"--trust-remote-code",
7676
}
77+
LLAMA_CPP_INFERENCE_RESTRICTED_PARAMS = {
78+
"--port",
79+
"--host",
80+
}

ads/aqua/modeldeployment/deployment.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Dict, List, Optional, Union
77

88
from ads.aqua.app import AquaApp, logger
9+
from ads.aqua.common.entities import ContainerSpec
910
from ads.aqua.common.enums import (
1011
InferenceContainerTypeFamily,
1112
Tags,
@@ -38,7 +39,6 @@
3839
from ads.aqua.modeldeployment.entities import (
3940
AquaDeployment,
4041
AquaDeploymentDetail,
41-
ContainerSpec,
4242
)
4343
from ads.aqua.ui import ModelFormat
4444
from ads.common.object_storage_details import ObjectStorageDetails
@@ -281,8 +281,10 @@ def create(
281281
f"Aqua Image used for deploying {aqua_model.id} : {container_image}"
282282
)
283283

284+
# todo: use AquaContainerConfig.from_container_index_json instead.
284285
# Fetch the startup cli command for the container
285-
# container_index.json will have "containerSpec" section which will provide the cli params for a given container family
286+
# container_index.json will have "containerSpec" section which will provide the cli params for
287+
# a given container family
286288
container_config = get_container_config()
287289
container_spec = container_config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
288290
container_type_key, {}
@@ -308,6 +310,18 @@ def create(
308310
# validate user provided params
309311
user_params = env_var.get("PARAMS", UNKNOWN)
310312
if user_params:
313+
# todo: remove this check in the future version, logic to be moved to container_index
314+
if (
315+
container_type_key.lower()
316+
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
317+
):
318+
# AQUA_LLAMA_CPP_CONTAINER_FAMILY container uses uvicorn that required model/server params
319+
# to be set as env vars
320+
raise AquaValueError(
321+
f"Currently, parameters cannot be overridden for the container: {container_image}. Please proceed "
322+
f"with deployment without parameter overrides."
323+
)
324+
311325
restricted_params = self._find_restricted_params(
312326
params, user_params, container_type_key
313327
)

ads/aqua/modeldeployment/entities.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65
from dataclasses import dataclass, field
76
from typing import Union
87

9-
from oci.data_science.models import ModelDeployment, ModelDeploymentSummary
8+
from oci.data_science.models import (
9+
ModelDeployment,
10+
ModelDeploymentSummary,
11+
)
1012

1113
from ads.aqua.common.enums import Tags
1214
from ads.aqua.constants import UNKNOWN, UNKNOWN_DICT
@@ -24,18 +26,6 @@ class ModelParams:
2426
model: str = None
2527

2628

27-
class ContainerSpec:
28-
"""
29-
Class to hold to hold keys within the container spec.
30-
"""
31-
32-
CONTAINER_SPEC = "containerSpec"
33-
CLI_PARM = "cliParam"
34-
SERVER_PORT = "serverPort"
35-
HEALTH_CHECK_PORT = "healthCheckPort"
36-
ENV_VARS = "envVars"
37-
38-
3929
@dataclass
4030
class ShapeInfo:
4131
instance_shape: str = None
@@ -61,6 +51,7 @@ class AquaDeployment(DataClassSerializable):
6151
lifecycle_details: str = None
6252
shape_info: field(default_factory=ShapeInfo) = None
6353
tags: dict = None
54+
environment_variables: dict = None
6455

6556
@classmethod
6657
def from_oci_model_deployment(
@@ -83,15 +74,12 @@ def from_oci_model_deployment(
8374
AquaDeployment:
8475
The instance of the Aqua model deployment.
8576
"""
86-
instance_configuration = (
87-
oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration
88-
)
77+
instance_configuration = oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration
8978
instance_shape_config_details = (
9079
instance_configuration.model_deployment_instance_shape_config_details
9180
)
92-
instance_count = (
93-
oci_model_deployment.model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count
94-
)
81+
instance_count = oci_model_deployment.model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count
82+
environment_variables = oci_model_deployment.model_deployment_configuration_details.environment_configuration_details.environment_variables
9583
shape_info = ShapeInfo(
9684
instance_shape=instance_configuration.instance_shape_name,
9785
instance_count=instance_count,
@@ -131,6 +119,7 @@ def from_oci_model_deployment(
131119
region=region,
132120
),
133121
tags=freeform_tags,
122+
environment_variables=environment_variables,
134123
)
135124

136125

ads/aqua/ui.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
from typing import Dict, List, Optional
1010

1111
from cachetools import TTLCache
12+
from oci.exceptions import ServiceError
13+
from oci.identity.models import Compartment
1214

1315
from ads.aqua import logger
1416
from ads.aqua.app import AquaApp
17+
from ads.aqua.common.entities import ContainerSpec
1518
from ads.aqua.common.enums import Tags
1619
from ads.aqua.common.errors import AquaResourceAccessError, AquaValueError
1720
from ads.aqua.common.utils import get_container_config, load_config, sanitize_response
@@ -27,8 +30,6 @@
2730
TENANCY_OCID,
2831
)
2932
from ads.telemetry import telemetry
30-
from oci.exceptions import ServiceError
31-
from oci.identity.models import Compartment
3233

3334

3435
class ModelFormat(Enum):
@@ -40,6 +41,19 @@ def to_dict(self):
4041
return self.value
4142

4243

44+
# todo: the container config spec information is shared across ui and deployment modules, move them
45+
# within ads.aqua.common.entities. In that case, check for circular imports due to usage of get_container_config.
46+
47+
48+
@dataclass(repr=False)
49+
class AquaContainerConfigSpec(DataClassSerializable):
50+
cli_param: str = None
51+
server_port: str = None
52+
health_check_port: str = None
53+
env_vars: List[dict] = None
54+
restricted_params: List[str] = None
55+
56+
4357
@dataclass(repr=False)
4458
class AquaContainerConfigItem(DataClassSerializable):
4559
"""Represents an item of the AQUA container configuration."""
@@ -60,6 +74,7 @@ def __repr__(self):
6074
family: str = None
6175
platforms: List[Platform] = None
6276
model_formats: List[ModelFormat] = None
77+
spec: AquaContainerConfigSpec = field(default_factory=AquaContainerConfigSpec)
6378

6479

6580
@dataclass(repr=False)
@@ -81,7 +96,9 @@ def to_dict(self):
8196

8297
@classmethod
8398
def from_container_index_json(
84-
cls, config: Optional[Dict] = None
99+
cls,
100+
config: Optional[Dict] = None,
101+
enable_spec: Optional[bool] = False,
85102
) -> "AquaContainerConfig":
86103
"""
87104
Create an AquaContainerConfig instance from a container index JSON.
@@ -90,6 +107,8 @@ def from_container_index_json(
90107
----------
91108
config : Dict
92109
The container index JSON.
110+
enable_spec: bool
111+
flag to check if container specification details should be fetched.
93112
94113
Returns
95114
-------
@@ -114,6 +133,13 @@ def from_container_index_json(
114133
ModelFormat[model_format]
115134
for model_format in container.get("modelFormats", [])
116135
]
136+
container_spec = (
137+
config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
138+
container_type, {}
139+
)
140+
if enable_spec
141+
else None
142+
)
117143
container_item = AquaContainerConfigItem(
118144
name=container.get("name", ""),
119145
version=container.get("version", ""),
@@ -123,6 +149,21 @@ def from_container_index_json(
123149
family=container_type,
124150
platforms=platforms,
125151
model_formats=model_formats,
152+
spec=AquaContainerConfigSpec(
153+
cli_param=container_spec.get(ContainerSpec.CLI_PARM, ""),
154+
server_port=container_spec.get(
155+
ContainerSpec.SERVER_PORT, ""
156+
),
157+
health_check_port=container_spec.get(
158+
ContainerSpec.HEALTH_CHECK_PORT, ""
159+
),
160+
env_vars=container_spec.get(ContainerSpec.ENV_VARS, []),
161+
restricted_params=container_spec.get(
162+
ContainerSpec.RESTRICTED_PARAMS, []
163+
),
164+
)
165+
if container_spec
166+
else None,
126167
)
127168
if container.get("type") == "inference":
128169
inference_items[container_type] = container_item
@@ -571,5 +612,6 @@ def list_containers(self) -> AquaContainerConfig:
571612
The AQUA containers configurations.
572613
"""
573614
return AquaContainerConfig.from_container_index_json(
574-
config=get_container_config()
615+
config=get_container_config(),
616+
enable_spec=True,
575617
)

tests/unitary/with_extras/aqua/test_data/ui/container_index.json

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
}
2121
],
2222
"healthCheckPort": "8080",
23+
"restrictedParams": [],
2324
"serverPort": "8080"
2425
},
2526
"odsc-tgi-serving": {
@@ -39,10 +40,17 @@
3940
}
4041
],
4142
"healthCheckPort": "8080",
43+
"restrictedParams": [
44+
"--port",
45+
"--hostname",
46+
"--num-shard",
47+
"--sharded",
48+
"--trust-remote-code"
49+
],
4250
"serverPort": "8080"
4351
},
4452
"odsc-vllm-serving": {
45-
"cliParam": "--served-model-name $(python -c 'import os; print(os.environ.get(\"ODSC_SERVED_MODEL_NAME\",\"odsc-llm\"))') --seed 42 ",
53+
"cliParam": "--served-model-name odsc-llm --seed 42 ",
4654
"envVars": [
4755
{
4856
"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions"
@@ -58,6 +66,12 @@
5866
}
5967
],
6068
"healthCheckPort": "8080",
69+
"restrictedParams": [
70+
"--port",
71+
"--host",
72+
"--served-model-name",
73+
"--seed"
74+
],
6175
"serverPort": "8080"
6276
}
6377
},

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,12 @@ class TestDataset:
181181
"created_on": "2024-01-01T00:00:00.000000+00:00",
182182
"created_by": "ocid1.user.oc1..<OCID>",
183183
"endpoint": MODEL_DEPLOYMENT_URL,
184+
"environment_variables": {
185+
"BASE_MODEL": "service_models/model-name/artifact",
186+
"MODEL_DEPLOY_ENABLE_STREAMING": "true",
187+
"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions",
188+
"PARAMS": "--served-model-name odsc-llm --seed 42",
189+
},
184190
"console_link": "https://cloud.oracle.com/data-science/model-deployments/ocid1.datasciencemodeldeployment.oc1.<region>.<MD_OCID>?region=region-name",
185191
"lifecycle_details": "",
186192
"shape_info": {
@@ -192,6 +198,14 @@ class TestDataset:
192198
"tags": {"OCI_AQUA": "active", "aqua_model_name": "model-name"},
193199
}
194200

201+
aqua_deployment_gguf_env_vars = {
202+
"BASE_MODEL": "service_models/model-name/artifact",
203+
"BASE_MODEL_FILE": "model-name.gguf",
204+
"MODEL_DEPLOY_ENABLE_STREAMING": "true",
205+
"MODEL_DEPLOY_HEALTH_ENDPOINT": "/v1/models",
206+
"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions",
207+
}
208+
195209
aqua_deployment_gguf_shape_info = {
196210
"instance_shape": DEPLOYMENT_SHAPE_NAME_CPU,
197211
"instance_count": 1,
@@ -544,6 +558,9 @@ def test_create_deployment_for_gguf_model(
544558
expected_result = copy.deepcopy(TestDataset.aqua_deployment_object)
545559
expected_result["state"] = "CREATING"
546560
expected_result["shape_info"] = TestDataset.aqua_deployment_gguf_shape_info
561+
expected_result["environment_variables"] = (
562+
TestDataset.aqua_deployment_gguf_env_vars
563+
)
547564
assert actual_attributes == expected_result
548565

549566
@parameterized.expand(

0 commit comments

Comments
 (0)