Skip to content

Commit ed67206

Browse files
authored
Merge branch 'main' into ODSC-64831/optimize-auto-select
2 parents dcd58c9 + 35c03c6 commit ed67206

21 files changed

+749
-423
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
name: "Forecast Explainer Tests"
2+
3+
on:
4+
workflow_dispatch:
5+
pull_request:
6+
branches: [ "main", "operators/**" ]
7+
8+
# Cancel in progress workflows on pull_requests.
9+
# https://docs.github.com/en/actions/using-jobs/using-concurrency#example-using-a-fallback-value
10+
concurrency:
11+
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
12+
cancel-in-progress: true
13+
14+
permissions:
15+
contents: read
16+
17+
env:
18+
SEGMENT_DOWNLOAD_TIMEOUT_MINS: 5
19+
20+
jobs:
21+
test:
22+
name: python ${{ matrix.python-version }}
23+
runs-on: ubuntu-latest
24+
timeout-minutes: 180
25+
26+
strategy:
27+
fail-fast: false
28+
matrix:
29+
python-version: ["3.10", "3.11"]
30+
31+
steps:
32+
- uses: actions/checkout@v4
33+
with:
34+
fetch-depth: 0
35+
ref: ${{ github.event.pull_request.head.sha }}
36+
37+
- uses: actions/setup-python@v5
38+
with:
39+
python-version: ${{ matrix.python-version }}
40+
cache: "pip"
41+
cache-dependency-path: |
42+
pyproject.toml
43+
"**requirements.txt"
44+
"test-requirements-operators.txt"
45+
46+
- uses: ./.github/workflows/set-dummy-conf
47+
name: "Test config setup"
48+
49+
- name: "Run Forecast Explainer Tests"
50+
timeout-minutes: 180
51+
shell: bash
52+
run: |
53+
set -x # print commands that are executed
54+
$CONDA/bin/conda init
55+
source /home/runner/.bashrc
56+
pip install -r test-requirements-operators.txt
57+
pip install "oracle-automlx[forecasting]>=25.1.1"
58+
pip install pandas>=2.2.0
59+
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast/test_explainers.py

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,4 @@ jobs:
5858
pip install -r test-requirements-operators.txt
5959
pip install "oracle-automlx[forecasting]>=25.1.1"
6060
pip install pandas>=2.2.0
61-
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast
61+
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast --ignore=tests/operators/forecast/test_explainers.py

ads/aqua/app.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
import os
77
import traceback
88
from dataclasses import fields
9-
from typing import Dict, Optional, Union
9+
from typing import Any, Dict, Optional, Union
1010

1111
import oci
1212
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
1313

1414
from ads import set_auth
1515
from ads.aqua import logger
16+
from ads.aqua.common.entities import ModelConfigResult
1617
from ads.aqua.common.enums import ConfigFolder, Tags
1718
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1819
from ads.aqua.common.utils import (
@@ -273,24 +274,24 @@ def get_config(
273274
model_id: str,
274275
config_file_name: str,
275276
config_folder: Optional[str] = ConfigFolder.CONFIG,
276-
) -> Dict:
277-
"""Gets the config for the given Aqua model.
277+
) -> ModelConfigResult:
278+
"""
279+
Gets the configuration for the given Aqua model along with the model details.
278280
279281
Parameters
280282
----------
281-
model_id: str
283+
model_id : str
282284
The OCID of the Aqua model.
283-
config_file_name: str
284-
name of the config file
285-
config_folder: (str, optional):
286-
subfolder path where config_file_name needs to be searched
287-
Defaults to `ConfigFolder.CONFIG`.
288-
When searching inside model artifact directory , the value is ConfigFolder.ARTIFACT`
285+
config_file_name : str
286+
The name of the configuration file.
287+
config_folder : Optional[str]
288+
The subfolder path where config_file_name is searched.
289+
Defaults to ConfigFolder.CONFIG. For model artifact directories, use ConfigFolder.ARTIFACT.
289290
290291
Returns
291292
-------
292-
Dict:
293-
A dict of allowed configs.
293+
ModelConfigResult
294+
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
294295
"""
295296
config_folder = config_folder or ConfigFolder.CONFIG
296297
oci_model = self.ds_client.get_model(model_id).data
@@ -302,11 +303,11 @@ def get_config(
302303
if oci_model.freeform_tags
303304
else False
304305
)
305-
306306
if not oci_aqua:
307-
raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
307+
raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
308+
309+
config: Dict[str, Any] = {}
308310

309-
config = {}
310311
# if the current model has a service model tag, then
311312
if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags:
312313
base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG]
@@ -326,7 +327,7 @@ def get_config(
326327
logger.debug(
327328
f"Failed to get artifact path from custom metadata for the model: {model_id}"
328329
)
329-
return config
330+
return ModelConfigResult(config=config, model_details=oci_model)
330331

331332
config_path = os.path.join(os.path.dirname(artifact_path), config_folder)
332333
if not is_path_exists(config_path):
@@ -351,9 +352,8 @@ def get_config(
351352
f"{config_file_name} is not available for the model: {model_id}. "
352353
f"Check if the custom metadata has the artifact path set."
353354
)
354-
return config
355355

356-
return config
356+
return ModelConfigResult(config=config, model_details=oci_model)
357357

358358
@property
359359
def telemetry(self):
@@ -375,9 +375,11 @@ def build_cli(self) -> str:
375375
"""
376376
cmd = f"ads aqua {self._command}"
377377
params = [
378-
f"--{field.name} {json.dumps(getattr(self, field.name))}"
379-
if isinstance(getattr(self, field.name), dict)
380-
else f"--{field.name} {getattr(self, field.name)}"
378+
(
379+
f"--{field.name} {json.dumps(getattr(self, field.name))}"
380+
if isinstance(getattr(self, field.name), dict)
381+
else f"--{field.name} {getattr(self, field.name)}"
382+
)
381383
for field in fields(self.__class__)
382384
if getattr(self, field.name) is not None
383385
]

ads/aqua/common/entities.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
from typing import Any, Dict, Optional
6+
7+
from oci.data_science.models import Model
8+
from pydantic import BaseModel, Field
9+
510

611
class ContainerSpec:
712
"""
@@ -15,3 +20,25 @@ class ContainerSpec:
1520
ENV_VARS = "envVars"
1621
RESTRICTED_PARAMS = "restrictedParams"
1722
EVALUATION_CONFIGURATION = "evaluationConfiguration"
23+
24+
25+
class ModelConfigResult(BaseModel):
26+
"""
27+
Represents the result of getting the AQUA model configuration.
28+
29+
Attributes:
30+
model_details (Dict[str, Any]): A dictionary containing model details extracted from OCI.
31+
config (Dict[str, Any]): A dictionary of the loaded configuration.
32+
"""
33+
34+
config: Optional[Dict[str, Any]] = Field(
35+
None, description="Loaded configuration dictionary."
36+
)
37+
model_details: Optional[Model] = Field(
38+
None, description="Details of the model from OCI."
39+
)
40+
41+
class Config:
42+
extra = "ignore"
43+
arbitrary_types_allowed = True
44+
protected_namespaces = ()

ads/aqua/config/container_config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ class AquaContainerConfigItem(Serializable):
8080
default_factory=AquaContainerConfigSpec,
8181
description="Detailed container specification.",
8282
)
83+
usages: Optional[List[str]] = Field(
84+
default_factory=list, description="Supported usages."
85+
)
8386

8487
class Config:
8588
extra = "allow"
@@ -131,7 +134,7 @@ def from_container_index_json(
131134
-------
132135
AquaContainerConfig: The constructed container configuration.
133136
"""
134-
#TODO: Return this logic back if necessary in the next iteraion.
137+
# TODO: Return this logic back if necessary in the next iteraion.
135138
# if not config:
136139
# config = get_container_config()
137140

@@ -144,6 +147,7 @@ def from_container_index_json(
144147
for container in containers:
145148
platforms = container.get("platforms", [])
146149
model_formats = container.get("modelFormats", [])
150+
usages = container.get("usages", [])
147151
container_spec = (
148152
config.get(ContainerSpec.CONTAINER_SPEC, {}).get(
149153
container_type, {}
@@ -160,6 +164,7 @@ def from_container_index_json(
160164
family=container_type,
161165
platforms=platforms,
162166
model_formats=model_formats,
167+
usages=usages,
163168
spec=(
164169
AquaContainerConfigSpec(
165170
cli_param=container_spec.get(

0 commit comments

Comments
 (0)