Skip to content

Commit d252d6d

Browse files
committed
Refactor Container Index JSON Loader. Move to Config Package & Switch to Pydantic
1 parent 71efc26 commit d252d6d

File tree

12 files changed

+274
-284
lines changed

12 files changed

+274
-284
lines changed

ads/aqua/common/enums.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,6 @@
22
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5-
"""
6-
aqua.common.enums
7-
~~~~~~~~~~~~~~
8-
This module contains the set of enums used in AQUA.
9-
"""
10-
115
from ads.common.extended_enum import ExtendedEnum
126

137

@@ -88,7 +82,8 @@ class RqsAdditionalDetails(ExtendedEnum):
8882

8983
class TextEmbeddingInferenceContainerParams(ExtendedEnum):
9084
"""Contains a subset of params that are required for enabling model deployment in OCI Data Science. More options
91-
are available at https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments"""
85+
are available at https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments
86+
"""
9287

9388
MODEL_ID = "model-id"
9489
PORT = "port"
@@ -97,3 +92,14 @@ class TextEmbeddingInferenceContainerParams(ExtendedEnum):
9792
class ConfigFolder(ExtendedEnum):
9893
CONFIG = "config"
9994
ARTIFACT = "artifact"
95+
96+
97+
class ModelFormat(ExtendedEnum):
98+
GGUF = "GGUF"
99+
SAFETENSORS = "SAFETENSORS"
100+
UNKNOWN = "UNKNOWN"
101+
102+
103+
class Platform(ExtendedEnum):
104+
ARM_CPU = "ARM_CPU"
105+
NVIDIA_GPU = "NVIDIA_GPU"

ads/aqua/common/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def service_config_path():
553553
return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
554554

555555

556-
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
556+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=1), timer=datetime.now))
557557
def get_container_config():
558558
config = load_config(
559559
file_path=service_config_path(),

ads/aqua/config/container_config.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2025 Oracle and/or its affiliates.
3+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4+
5+
from typing import Dict, List, Optional
6+
7+
from pydantic import Field
8+
9+
from ads.aqua.common.entities import ContainerSpec
10+
from ads.aqua.config.utils.serializer import Serializable
11+
12+
13+
class AquaContainerConfigSpec(Serializable):
14+
"""
15+
Represents container specification details.
16+
17+
Attributes
18+
----------
19+
cli_param (Optional[str]): CLI parameter for container configuration.
20+
server_port (Optional[str]): The server port for the container.
21+
health_check_port (Optional[str]): The health check port for the container.
22+
env_vars (Optional[List[Dict]]): Environment variables for the container.
23+
restricted_params (Optional[List[str]]): Restricted parameters for container configuration.
24+
"""
25+
26+
cli_param: Optional[str] = Field(
27+
default=None, description="CLI parameter for container configuration."
28+
)
29+
server_port: Optional[str] = Field(
30+
default=None, description="Server port for the container."
31+
)
32+
health_check_port: Optional[str] = Field(
33+
default=None, description="Health check port for the container."
34+
)
35+
env_vars: Optional[List[Dict]] = Field(
36+
default_factory=list, description="List of environment variables."
37+
)
38+
restricted_params: Optional[List[str]] = Field(
39+
default_factory=list, description="List of restricted parameters."
40+
)
41+
42+
class Config:
43+
extra = "allow"
44+
45+
46+
class AquaContainerConfigItem(Serializable):
47+
"""
48+
Represents an item of the AQUA container configuration.
49+
50+
Attributes
51+
----------
52+
name (Optional[str]): Name of the container configuration item.
53+
version (Optional[str]): Version of the container.
54+
display_name (Optional[str]): Display name for UI.
55+
family (Optional[str]): Container family or category.
56+
platforms (Optional[List[str]]): Supported platforms.
57+
model_formats (Optional[List[str]]): Supported model formats.
58+
spec (Optional[AquaContainerConfigSpec]): Container specification details.
59+
"""
60+
61+
name: Optional[str] = Field(
62+
default=None, description="Name of the container configuration item."
63+
)
64+
version: Optional[str] = Field(
65+
default=None, description="Version of the container."
66+
)
67+
display_name: Optional[str] = Field(
68+
default=None, description="Display name of the container."
69+
)
70+
family: Optional[str] = Field(
71+
default=None, description="Container family or category."
72+
)
73+
platforms: Optional[List[str]] = Field(
74+
default_factory=list, description="Supported platforms."
75+
)
76+
model_formats: Optional[List[str]] = Field(
77+
default_factory=list, description="Supported model formats."
78+
)
79+
spec: Optional[AquaContainerConfigSpec] = Field(
80+
default_factory=AquaContainerConfigSpec,
81+
description="Detailed container specification.",
82+
)
83+
84+
class Config:
85+
extra = "allow"
86+
87+
88+
class AquaContainerConfig(Serializable):
89+
"""
90+
Represents a configuration of AQUA containers to be returned to the client.
91+
92+
Attributes
93+
----------
94+
inference (Dict[str, AquaContainerConfigItem]): Inference container configuration items.
95+
finetune (Dict[str, AquaContainerConfigItem]): Fine-tuning container configuration items.
96+
evaluate (Dict[str, AquaContainerConfigItem]): Evaluation container configuration items.
97+
"""
98+
99+
inference: Dict[str, AquaContainerConfigItem] = Field(
100+
default_factory=dict, description="Inference container configuration items."
101+
)
102+
finetune: Dict[str, AquaContainerConfigItem] = Field(
103+
default_factory=dict, description="Fine-tuning container configuration items."
104+
)
105+
evaluate: Dict[str, AquaContainerConfigItem] = Field(
106+
default_factory=dict, description="Evaluation container configuration items."
107+
)
108+
109+
def to_dict(self):
110+
return {
111+
"inference": list(self.inference.values()),
112+
"finetune": list(self.finetune.values()),
113+
"evaluate": list(self.evaluate.values()),
114+
}
115+
116+
@classmethod
117+
def from_container_index_json(
118+
cls,
119+
config: Optional[Dict] = None,
120+
enable_spec: Optional[bool] = False,
121+
) -> "AquaContainerConfig":
122+
"""
123+
Creates an AquaContainerConfig instance from a container index JSON.
124+
125+
Parameters
126+
----------
127+
config (Optional[Dict]): The container index JSON.
128+
enable_spec (Optional[bool]): If True, fetch container specification details.
129+
130+
Returns
131+
-------
132+
AquaContainerConfig: The constructed container configuration.
133+
"""
134+
# if not config:
135+
# config = get_container_config()
136+
137+
inference_items: Dict[str, AquaContainerConfigItem] = {}
138+
finetune_items: Dict[str, AquaContainerConfigItem] = {}
139+
evaluate_items: Dict[str, AquaContainerConfigItem] = {}
140+
141+
for container_type, containers in config.items():
142+
if isinstance(containers, list):
143+
for container in containers:
144+
platforms = container.get("platforms", [])
145+
model_formats = container.get("modelFormats", [])
146+
container_spec = (
147+
config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
148+
container_type, {}
149+
)
150+
if enable_spec
151+
else None
152+
)
153+
container_item = AquaContainerConfigItem(
154+
name=container.get("name", ""),
155+
version=container.get("version", ""),
156+
display_name=container.get(
157+
"displayName", container.get("version", "")
158+
),
159+
family=container_type,
160+
platforms=platforms,
161+
model_formats=model_formats,
162+
spec=(
163+
AquaContainerConfigSpec(
164+
cli_param=container_spec.get(
165+
ContainerSpec.CLI_PARM, ""
166+
),
167+
server_port=container_spec.get(
168+
ContainerSpec.SERVER_PORT, ""
169+
),
170+
health_check_port=container_spec.get(
171+
ContainerSpec.HEALTH_CHECK_PORT, ""
172+
),
173+
env_vars=container_spec.get(ContainerSpec.ENV_VARS, []),
174+
restricted_params=container_spec.get(
175+
ContainerSpec.RESTRICTED_PARAMS, []
176+
),
177+
)
178+
if container_spec
179+
else None
180+
),
181+
)
182+
if container.get("type") == "inference":
183+
inference_items[container_type] = container_item
184+
elif (
185+
container.get("type") == "fine-tune"
186+
or container_type == "odsc-llm-fine-tuning"
187+
):
188+
finetune_items[container_type] = container_item
189+
elif (
190+
container.get("type") == "evaluate"
191+
or container_type == "odsc-llm-evaluate"
192+
):
193+
evaluate_items[container_type] = container_item
194+
195+
return cls(
196+
inference=inference_items, finetune=finetune_items, evaluate=evaluate_items
197+
)

ads/aqua/evaluation/evaluation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@
4040
from ads.aqua.common.utils import (
4141
extract_id_and_name_from_tag,
4242
fire_and_forget,
43+
get_container_config,
4344
get_container_image,
4445
is_valid_ocid,
4546
upload_local_to_os,
4647
)
4748
from ads.aqua.config.config import get_evaluation_service_config
49+
from ads.aqua.config.container_config import AquaContainerConfig
4850
from ads.aqua.constants import (
4951
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
5052
EVALUATION_REPORT,
@@ -75,7 +77,6 @@
7577
CreateAquaEvaluationDetails,
7678
)
7779
from ads.aqua.evaluation.errors import EVALUATION_JOB_EXIT_CODE_MESSAGE
78-
from ads.aqua.ui import AquaContainerConfig
7980
from ads.common.auth import default_signer
8081
from ads.common.object_storage_details import ObjectStorageDetails
8182
from ads.common.utils import get_console_link, get_files, get_log_links
@@ -192,7 +193,7 @@ def create(
192193
evaluation_source.runtime.to_dict()
193194
)
194195
inference_config = AquaContainerConfig.from_container_index_json(
195-
enable_spec=True
196+
config=get_container_config(), enable_spec=True
196197
).inference
197198
for container in inference_config.values():
198199
if container.name == runtime.image[: runtime.image.rfind(":")]:

ads/aqua/extension/model_handler.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,13 @@
88
from tornado.web import HTTPError
99

1010
from ads.aqua.common.decorator import handle_exceptions
11-
from ads.aqua.common.enums import (
12-
CustomInferenceContainerTypeFamily,
13-
)
11+
from ads.aqua.common.enums import CustomInferenceContainerTypeFamily
1412
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
15-
from ads.aqua.common.utils import (
16-
get_hf_model_info,
17-
is_valid_ocid,
18-
list_hf_models,
19-
)
13+
from ads.aqua.common.utils import get_hf_model_info, is_valid_ocid, list_hf_models
2014
from ads.aqua.extension.base_handler import AquaAPIhandler
2115
from ads.aqua.extension.errors import Errors
2216
from ads.aqua.model import AquaModelApp
2317
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
24-
from ads.aqua.ui import ModelFormat
2518

2619

2720
class AquaModelHandler(AquaAPIhandler):
@@ -45,7 +38,7 @@ def get(
4538
400, Errors.MISSING_REQUIRED_PARAMETER.format("model_format")
4639
)
4740
try:
48-
model_format = ModelFormat(model_format.upper())
41+
model_format = model_format.upper()
4942
except ValueError as err:
5043
raise AquaValueError(f"Invalid model format: {model_format}") from err
5144
else:

ads/aqua/model/entities.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from ads.aqua.data import AquaResourceIdentifier
2424
from ads.aqua.model.enums import FineTuningDefinedMetadata
2525
from ads.aqua.training.exceptions import exit_code_dict
26-
from ads.aqua.ui import ModelFormat
2726
from ads.common.serializer import DataClassSerializable
2827
from ads.common.utils import get_log_links
2928
from ads.model.datascience_model import DataScienceModel
@@ -46,7 +45,7 @@ class AquaFineTuneValidation(DataClassSerializable):
4645
@dataclass(repr=False)
4746
class ModelValidationResult:
4847
model_file: Optional[str] = None
49-
model_formats: List[ModelFormat] = field(default_factory=list)
48+
model_formats: List[str] = field(default_factory=list)
5049
telemetry_model_name: str = None
5150
tags: Optional[dict] = None
5251

@@ -89,7 +88,7 @@ class AquaModelSummary(DataClassSerializable):
8988
nvidia_gpu_supported: bool = False
9089
arm_cpu_supported: bool = False
9190
model_file: Optional[str] = None
92-
model_formats: List[ModelFormat] = field(default_factory=list)
91+
model_formats: List[str] = field(default_factory=list)
9392

9493

9594
@dataclass(repr=False)

0 commit comments

Comments
 (0)