Skip to content

Commit 6abad1e

Browse files
update enum name
1 parent 7d760ea commit 6abad1e

File tree

4 files changed

+27
-30
lines changed

4 files changed

+27
-30
lines changed

ads/aqua/common/enums.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,21 @@ class Tags(str, metaclass=ExtendedEnumMeta):
4040
AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
4141

4242

43+
class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
44+
CONTAINER_TYPE_VLLM = "vllm"
45+
CONTAINER_TYPE_TGI = "tgi"
46+
47+
48+
class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
49+
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
50+
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
51+
52+
53+
class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
54+
PARAM_TYPE_VLLM = "VLLM_PARAMS"
55+
PARAM_TYPE_TGI = "TGI_PARAMS"
56+
57+
4358
class HuggingFaceTags(str, metaclass=ExtendedEnumMeta):
4459
TEXT_GENERATION_INFERENCE = "text-generation-inference"
4560

ads/aqua/model/model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
1515
from ads.aqua.app import AquaApp
16-
from ads.aqua.common.enums import Tags, HuggingFaceTags
16+
from ads.aqua.common.enums import Tags, HuggingFaceTags, InferenceContainerTypeFamily
1717
from ads.aqua.common.errors import AquaRuntimeError
1818
from ads.aqua.common.utils import (
1919
create_word_icon,
@@ -38,7 +38,6 @@
3838
)
3939
from ads.aqua.model.constants import *
4040
from ads.aqua.model.entities import *
41-
from ads.aqua.modeldeployment.enums import InferenceContainerTypeKey
4241
from ads.common.auth import default_signer
4342
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
4443
from ads.common.utils import get_console_link
@@ -628,11 +627,11 @@ def _create_model_catalog_entry(
628627

629628
if not inference_container:
630629
inference_container = (
631-
InferenceContainerTypeKey.AQUA_TGI_CONTAINER_KEY
630+
InferenceContainerTypeFamily.AQUA_TGI_CONTAINER_FAMILY
632631
if model_info
633632
and model_info.tags
634633
and HuggingFaceTags.TEXT_GENERATION_INFERENCE in model_info.tags
635-
else InferenceContainerTypeKey.AQUA_VLLM_CONTAINER_KEY
634+
else InferenceContainerTypeFamily.AQUA_VLLM_CONTAINER_FAMILY
636635
)
637636
logger.info(
638637
f"Model: {model_name} does not have associated inference container defaults. "

ads/aqua/modeldeployment/deployment.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010
from oci.data_science.models import ModelDeployment
1111

1212
from ads.aqua.app import AquaApp, logger
13-
from ads.aqua.common.enums import Tags
13+
from ads.aqua.common.enums import (
14+
Tags,
15+
InferenceContainerParamType,
16+
InferenceContainerType,
17+
InferenceContainerTypeFamily,
18+
)
1419
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1520
from ads.aqua.common.utils import (
1621
get_container_config,
@@ -42,11 +47,6 @@
4247
VLLMInferenceRestrictedParams,
4348
TGIInferenceRestrictedParams,
4449
)
45-
from ads.aqua.modeldeployment.enums import (
46-
InferenceContainerParamType,
47-
InferenceContainerType,
48-
InferenceContainerTypeKey,
49-
)
5050
from ads.common.object_storage_details import ObjectStorageDetails
5151
from ads.common.utils import get_log_links
5252
from ads.config import (
@@ -241,7 +241,7 @@ def create(
241241
if not container_family:
242242
raise AquaValueError(
243243
f"{message}. For unverified Aqua models, container_family parameter should be "
244-
f"set and value can be one of {', '.join(InferenceContainerTypeKey.values())}."
244+
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
245245
)
246246
container_type_key = container_family
247247
try:
@@ -572,7 +572,7 @@ def get_deployment_default_params(
572572

573573
if container_type_key:
574574
container_type_key = container_type_key.lower()
575-
if container_type_key in InferenceContainerTypeKey.values():
575+
if container_type_key in InferenceContainerTypeFamily.values():
576576
deployment_config = self.get_deployment_config(model_id)
577577
config_parameters = (
578578
deployment_config.get("configuration", UNKNOWN_DICT)
@@ -639,7 +639,7 @@ def validate_deployment_params(
639639
if not container_family:
640640
raise AquaValueError(
641641
f"{message}. For unverified Aqua models, container_family parameter should be "
642-
f"set and value can be one of {', '.join(InferenceContainerTypeKey.values())}."
642+
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
643643
)
644644
container_type_key = container_family
645645

ads/aqua/modeldeployment/enums.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,3 @@
33

44
# Copyright (c) 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6-
7-
from ads.common.extended_enum import ExtendedEnumMeta
8-
9-
10-
class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
11-
CONTAINER_TYPE_VLLM = "vllm"
12-
CONTAINER_TYPE_TGI = "tgi"
13-
14-
15-
class InferenceContainerTypeKey(str, metaclass=ExtendedEnumMeta):
16-
AQUA_VLLM_CONTAINER_KEY = "odsc-vllm-serving"
17-
AQUA_TGI_CONTAINER_KEY = "odsc-tgi-serving"
18-
19-
20-
class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
21-
PARAM_TYPE_VLLM = "VLLM_PARAMS"
22-
PARAM_TYPE_TGI = "TGI_PARAMS"

0 commit comments

Comments
 (0)