Skip to content

Commit c1b1b96

Browse files
Merge branch 'main' into feature/aqua_ms_changes_2
2 parents 0813099 + 788920c commit c1b1b96

File tree

20 files changed

+992
-317
lines changed

20 files changed

+992
-317
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
name: "Forecast Explainer Tests"
2+
3+
on:
4+
workflow_dispatch:
5+
pull_request:
6+
branches: [ "main", "operators/**" ]
7+
8+
# Cancel in progress workflows on pull_requests.
9+
# https://docs.github.com/en/actions/using-jobs/using-concurrency#example-using-a-fallback-value
10+
concurrency:
11+
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
12+
cancel-in-progress: true
13+
14+
permissions:
15+
contents: read
16+
17+
env:
18+
SEGMENT_DOWNLOAD_TIMEOUT_MINS: 5
19+
20+
jobs:
21+
test:
22+
name: python ${{ matrix.python-version }}
23+
runs-on: ubuntu-latest
24+
timeout-minutes: 180
25+
26+
strategy:
27+
fail-fast: false
28+
matrix:
29+
python-version: ["3.10", "3.11"]
30+
31+
steps:
32+
- uses: actions/checkout@v4
33+
with:
34+
fetch-depth: 0
35+
ref: ${{ github.event.pull_request.head.sha }}
36+
37+
- uses: actions/setup-python@v5
38+
with:
39+
python-version: ${{ matrix.python-version }}
40+
cache: "pip"
41+
cache-dependency-path: |
42+
pyproject.toml
43+
"**requirements.txt"
44+
"test-requirements-operators.txt"
45+
46+
- uses: ./.github/workflows/set-dummy-conf
47+
name: "Test config setup"
48+
49+
- name: "Run Forecast Explainer Tests"
50+
timeout-minutes: 180
51+
shell: bash
52+
run: |
53+
set -x # print commands that are executed
54+
$CONDA/bin/conda init
55+
source /home/runner/.bashrc
56+
pip install -r test-requirements-operators.txt
57+
pip install "oracle-automlx[forecasting]>=25.1.1"
58+
pip install pandas>=2.2.0
59+
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast/test_explainers.py

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,4 @@ jobs:
5858
pip install -r test-requirements-operators.txt
5959
pip install "oracle-automlx[forecasting]>=25.1.1"
6060
pip install pandas>=2.2.0
61-
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast
61+
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast --ignore=tests/operators/forecast/test_explainers.py

ads/aqua/common/enums.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,6 @@
22
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5-
"""
6-
aqua.common.enums
7-
~~~~~~~~~~~~~~
8-
This module contains the set of enums used in AQUA.
9-
"""
10-
115
from ads.common.extended_enum import ExtendedEnum
126

137

@@ -88,7 +82,8 @@ class RqsAdditionalDetails(ExtendedEnum):
8882

8983
class TextEmbeddingInferenceContainerParams(ExtendedEnum):
9084
"""Contains a subset of params that are required for enabling model deployment in OCI Data Science. More options
91-
are available at https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments"""
85+
are available at https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments
86+
"""
9287

9388
MODEL_ID = "model-id"
9489
PORT = "port"
@@ -97,3 +92,14 @@ class TextEmbeddingInferenceContainerParams(ExtendedEnum):
9792
class ConfigFolder(ExtendedEnum):
9893
CONFIG = "config"
9994
ARTIFACT = "artifact"
95+
96+
97+
class ModelFormat(ExtendedEnum):
98+
GGUF = "GGUF"
99+
SAFETENSORS = "SAFETENSORS"
100+
UNKNOWN = "UNKNOWN"
101+
102+
103+
class Platform(ExtendedEnum):
104+
ARM_CPU = "ARM_CPU"
105+
NVIDIA_GPU = "NVIDIA_GPU"

ads/aqua/common/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,7 @@ def service_config_path():
548548
return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
549549

550550

551-
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
551+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=10), timer=datetime.now))
552552
def get_container_config():
553553
config = load_config(
554554
file_path=service_config_path(),

ads/aqua/config/container_config.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2025 Oracle and/or its affiliates.
3+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4+
5+
from typing import Dict, List, Optional
6+
7+
from pydantic import Field
8+
9+
from ads.aqua.common.entities import ContainerSpec
10+
from ads.aqua.config.utils.serializer import Serializable
11+
12+
13+
class AquaContainerConfigSpec(Serializable):
14+
"""
15+
Represents container specification details.
16+
17+
Attributes
18+
----------
19+
cli_param (Optional[str]): CLI parameter for container configuration.
20+
server_port (Optional[str]): The server port for the container.
21+
health_check_port (Optional[str]): The health check port for the container.
22+
env_vars (Optional[List[Dict]]): Environment variables for the container.
23+
restricted_params (Optional[List[str]]): Restricted parameters for container configuration.
24+
"""
25+
26+
cli_param: Optional[str] = Field(
27+
default=None, description="CLI parameter for container configuration."
28+
)
29+
server_port: Optional[str] = Field(
30+
default=None, description="Server port for the container."
31+
)
32+
health_check_port: Optional[str] = Field(
33+
default=None, description="Health check port for the container."
34+
)
35+
env_vars: Optional[List[Dict]] = Field(
36+
default_factory=list, description="List of environment variables."
37+
)
38+
restricted_params: Optional[List[str]] = Field(
39+
default_factory=list, description="List of restricted parameters."
40+
)
41+
42+
class Config:
43+
extra = "allow"
44+
45+
46+
class AquaContainerConfigItem(Serializable):
47+
"""
48+
Represents an item of the AQUA container configuration.
49+
50+
Attributes
51+
----------
52+
name (Optional[str]): Name of the container configuration item.
53+
version (Optional[str]): Version of the container.
54+
display_name (Optional[str]): Display name for UI.
55+
family (Optional[str]): Container family or category.
56+
platforms (Optional[List[str]]): Supported platforms.
57+
model_formats (Optional[List[str]]): Supported model formats.
58+
spec (Optional[AquaContainerConfigSpec]): Container specification details.
59+
"""
60+
61+
name: Optional[str] = Field(
62+
default=None, description="Name of the container configuration item."
63+
)
64+
version: Optional[str] = Field(
65+
default=None, description="Version of the container."
66+
)
67+
display_name: Optional[str] = Field(
68+
default=None, description="Display name of the container."
69+
)
70+
family: Optional[str] = Field(
71+
default=None, description="Container family or category."
72+
)
73+
platforms: Optional[List[str]] = Field(
74+
default_factory=list, description="Supported platforms."
75+
)
76+
model_formats: Optional[List[str]] = Field(
77+
default_factory=list, description="Supported model formats."
78+
)
79+
spec: Optional[AquaContainerConfigSpec] = Field(
80+
default_factory=AquaContainerConfigSpec,
81+
description="Detailed container specification.",
82+
)
83+
usages: Optional[List[str]] = Field(
84+
default_factory=list, description="Supported usages."
85+
)
86+
87+
class Config:
88+
extra = "allow"
89+
90+
91+
class AquaContainerConfig(Serializable):
92+
"""
93+
Represents a configuration of AQUA containers to be returned to the client.
94+
95+
Attributes
96+
----------
97+
inference (Dict[str, AquaContainerConfigItem]): Inference container configuration items.
98+
finetune (Dict[str, AquaContainerConfigItem]): Fine-tuning container configuration items.
99+
evaluate (Dict[str, AquaContainerConfigItem]): Evaluation container configuration items.
100+
"""
101+
102+
inference: Dict[str, AquaContainerConfigItem] = Field(
103+
default_factory=dict, description="Inference container configuration items."
104+
)
105+
finetune: Dict[str, AquaContainerConfigItem] = Field(
106+
default_factory=dict, description="Fine-tuning container configuration items."
107+
)
108+
evaluate: Dict[str, AquaContainerConfigItem] = Field(
109+
default_factory=dict, description="Evaluation container configuration items."
110+
)
111+
112+
def to_dict(self):
113+
return {
114+
"inference": list(self.inference.values()),
115+
"finetune": list(self.finetune.values()),
116+
"evaluate": list(self.evaluate.values()),
117+
}
118+
119+
@classmethod
120+
def from_container_index_json(
121+
cls,
122+
config: Dict,
123+
enable_spec: Optional[bool] = False,
124+
) -> "AquaContainerConfig":
125+
"""
126+
Creates an AquaContainerConfig instance from a container index JSON.
127+
128+
Parameters
129+
----------
130+
config (Optional[Dict]): The container index JSON.
131+
enable_spec (Optional[bool]): If True, fetch container specification details.
132+
133+
Returns
134+
-------
135+
AquaContainerConfig: The constructed container configuration.
136+
"""
137+
# TODO: Return this logic back if necessary in the next iteraion.
138+
# if not config:
139+
# config = get_container_config()
140+
141+
inference_items: Dict[str, AquaContainerConfigItem] = {}
142+
finetune_items: Dict[str, AquaContainerConfigItem] = {}
143+
evaluate_items: Dict[str, AquaContainerConfigItem] = {}
144+
145+
for container_type, containers in config.items():
146+
if isinstance(containers, list):
147+
for container in containers:
148+
platforms = container.get("platforms", [])
149+
model_formats = container.get("modelFormats", [])
150+
usages = container.get("usages", [])
151+
container_spec = (
152+
config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
153+
container_type, {}
154+
)
155+
if enable_spec
156+
else None
157+
)
158+
container_item = AquaContainerConfigItem(
159+
name=container.get("name", ""),
160+
version=container.get("version", ""),
161+
display_name=container.get(
162+
"displayName", container.get("version", "")
163+
),
164+
family=container_type,
165+
platforms=platforms,
166+
model_formats=model_formats,
167+
usages=usages,
168+
spec=(
169+
AquaContainerConfigSpec(
170+
cli_param=container_spec.get(
171+
ContainerSpec.CLI_PARM, ""
172+
),
173+
server_port=container_spec.get(
174+
ContainerSpec.SERVER_PORT, ""
175+
),
176+
health_check_port=container_spec.get(
177+
ContainerSpec.HEALTH_CHECK_PORT, ""
178+
),
179+
env_vars=container_spec.get(ContainerSpec.ENV_VARS, []),
180+
restricted_params=container_spec.get(
181+
ContainerSpec.RESTRICTED_PARAMS, []
182+
),
183+
)
184+
if container_spec
185+
else None
186+
),
187+
)
188+
if container.get("type") == "inference":
189+
inference_items[container_type] = container_item
190+
elif (
191+
container.get("type") == "fine-tune"
192+
or container_type == "odsc-llm-fine-tuning"
193+
):
194+
finetune_items[container_type] = container_item
195+
elif (
196+
container.get("type") == "evaluate"
197+
or container_type == "odsc-llm-evaluate"
198+
):
199+
evaluate_items[container_type] = container_item
200+
201+
return cls(
202+
inference=inference_items, finetune=finetune_items, evaluate=evaluate_items
203+
)

ads/aqua/evaluation/evaluation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@
4040
from ads.aqua.common.utils import (
4141
extract_id_and_name_from_tag,
4242
fire_and_forget,
43+
get_container_config,
4344
get_container_image,
4445
is_valid_ocid,
4546
upload_local_to_os,
4647
)
4748
from ads.aqua.config.config import get_evaluation_service_config
49+
from ads.aqua.config.container_config import AquaContainerConfig
4850
from ads.aqua.constants import (
4951
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
5052
EVALUATION_REPORT,
@@ -74,7 +76,6 @@
7476
CreateAquaEvaluationDetails,
7577
)
7678
from ads.aqua.evaluation.errors import EVALUATION_JOB_EXIT_CODE_MESSAGE
77-
from ads.aqua.ui import AquaContainerConfig
7879
from ads.common.auth import default_signer
7980
from ads.common.object_storage_details import ObjectStorageDetails
8081
from ads.common.utils import UNKNOWN, get_console_link, get_files, get_log_links
@@ -191,7 +192,7 @@ def create(
191192
evaluation_source.runtime.to_dict()
192193
)
193194
inference_config = AquaContainerConfig.from_container_index_json(
194-
enable_spec=True
195+
config=get_container_config(), enable_spec=True
195196
).inference
196197
for container in inference_config.values():
197198
if container.name == runtime.image[: runtime.image.rfind(":")]:

0 commit comments

Comments
 (0)