Skip to content

Commit 4628262

Browse files
authored
Merge branch 'main' into forecasting/ODSC68407/output_format
2 parents 65eb5a3 + c7d3b3d commit 4628262

File tree

65 files changed

+2490
-924
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+2490
-924
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
name: "Forecast Explainer Tests"
2+
3+
on:
4+
workflow_dispatch:
5+
pull_request:
6+
branches: [ "main", "operators/**" ]
7+
8+
# Cancel in progress workflows on pull_requests.
9+
# https://docs.github.com/en/actions/using-jobs/using-concurrency#example-using-a-fallback-value
10+
concurrency:
11+
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
12+
cancel-in-progress: true
13+
14+
permissions:
15+
contents: read
16+
17+
env:
18+
SEGMENT_DOWNLOAD_TIMEOUT_MINS: 5
19+
20+
jobs:
21+
test:
22+
name: python ${{ matrix.python-version }}
23+
runs-on: ubuntu-latest
24+
timeout-minutes: 180
25+
26+
strategy:
27+
fail-fast: false
28+
matrix:
29+
python-version: ["3.10", "3.11"]
30+
31+
steps:
32+
- uses: actions/checkout@v4
33+
with:
34+
fetch-depth: 0
35+
ref: ${{ github.event.pull_request.head.sha }}
36+
37+
- uses: actions/setup-python@v5
38+
with:
39+
python-version: ${{ matrix.python-version }}
40+
cache: "pip"
41+
cache-dependency-path: |
42+
pyproject.toml
43+
"**requirements.txt"
44+
"test-requirements-operators.txt"
45+
46+
- uses: ./.github/workflows/set-dummy-conf
47+
name: "Test config setup"
48+
49+
- name: "Run Forecast Explainer Tests"
50+
timeout-minutes: 180
51+
shell: bash
52+
run: |
53+
set -x # print commands that are executed
54+
$CONDA/bin/conda init
55+
source /home/runner/.bashrc
56+
pip install -r test-requirements-operators.txt
57+
pip install "oracle-automlx[forecasting]>=25.1.1"
58+
pip install pandas>=2.2.0
59+
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast/test_explainers.py

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,4 @@ jobs:
5858
pip install -r test-requirements-operators.txt
5959
pip install "oracle-automlx[forecasting]>=25.1.1"
6060
pip install pandas>=2.2.0
61-
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast
61+
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast --ignore=tests/operators/forecast/test_explainers.py

ads/aqua/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77
from logging import getLogger
88

99
from ads import logger, set_auth
10-
from ads.aqua.client.client import AsyncClient, Client
10+
from ads.aqua.client.client import (
11+
AsyncClient,
12+
Client,
13+
HttpxOCIAuth,
14+
get_async_httpx_client,
15+
get_httpx_client,
16+
)
1117
from ads.aqua.common.utils import fetch_service_compartment
1218
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
1319

ads/aqua/app.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
import os
77
import traceback
88
from dataclasses import fields
9-
from typing import Dict, Optional, Union
9+
from typing import Any, Dict, Optional, Union
1010

1111
import oci
1212
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
1313

1414
from ads import set_auth
1515
from ads.aqua import logger
16+
from ads.aqua.common.entities import ModelConfigResult
1617
from ads.aqua.common.enums import ConfigFolder, Tags
1718
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1819
from ads.aqua.common.utils import (
@@ -21,10 +22,9 @@
2122
is_valid_ocid,
2223
load_config,
2324
)
24-
from ads.aqua.constants import UNKNOWN
2525
from ads.common import oci_client as oc
2626
from ads.common.auth import default_signer
27-
from ads.common.utils import extract_region, is_path_exists
27+
from ads.common.utils import UNKNOWN, extract_region, is_path_exists
2828
from ads.config import (
2929
AQUA_TELEMETRY_BUCKET,
3030
AQUA_TELEMETRY_BUCKET_NS,
@@ -273,24 +273,24 @@ def get_config(
273273
model_id: str,
274274
config_file_name: str,
275275
config_folder: Optional[str] = ConfigFolder.CONFIG,
276-
) -> Dict:
277-
"""Gets the config for the given Aqua model.
276+
) -> ModelConfigResult:
277+
"""
278+
Gets the configuration for the given Aqua model along with the model details.
278279
279280
Parameters
280281
----------
281-
model_id: str
282+
model_id : str
282283
The OCID of the Aqua model.
283-
config_file_name: str
284-
name of the config file
285-
config_folder: (str, optional):
286-
subfolder path where config_file_name needs to be searched
287-
Defaults to `ConfigFolder.CONFIG`.
288-
When searching inside model artifact directory , the value is ConfigFolder.ARTIFACT`
284+
config_file_name : str
285+
The name of the configuration file.
286+
config_folder : Optional[str]
287+
The subfolder path where config_file_name is searched.
288+
Defaults to ConfigFolder.CONFIG. For model artifact directories, use ConfigFolder.ARTIFACT.
289289
290290
Returns
291291
-------
292-
Dict:
293-
A dict of allowed configs.
292+
ModelConfigResult
293+
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
294294
"""
295295
config_folder = config_folder or ConfigFolder.CONFIG
296296
oci_model = self.ds_client.get_model(model_id).data
@@ -302,11 +302,11 @@ def get_config(
302302
if oci_model.freeform_tags
303303
else False
304304
)
305-
306305
if not oci_aqua:
307-
raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
306+
raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
307+
308+
config: Dict[str, Any] = {}
308309

309-
config = {}
310310
# if the current model has a service model tag, then
311311
if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags:
312312
base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG]
@@ -326,7 +326,7 @@ def get_config(
326326
logger.debug(
327327
f"Failed to get artifact path from custom metadata for the model: {model_id}"
328328
)
329-
return config
329+
return ModelConfigResult(config=config, model_details=oci_model)
330330

331331
config_path = os.path.join(os.path.dirname(artifact_path), config_folder)
332332
if not is_path_exists(config_path):
@@ -351,9 +351,8 @@ def get_config(
351351
f"{config_file_name} is not available for the model: {model_id}. "
352352
f"Check if the custom metadata has the artifact path set."
353353
)
354-
return config
355354

356-
return config
355+
return ModelConfigResult(config=config, model_details=oci_model)
357356

358357
@property
359358
def telemetry(self):
@@ -375,9 +374,11 @@ def build_cli(self) -> str:
375374
"""
376375
cmd = f"ads aqua {self._command}"
377376
params = [
378-
f"--{field.name} {json.dumps(getattr(self, field.name))}"
379-
if isinstance(getattr(self, field.name), dict)
380-
else f"--{field.name} {getattr(self, field.name)}"
377+
(
378+
f"--{field.name} {json.dumps(getattr(self, field.name))}"
379+
if isinstance(getattr(self, field.name), dict)
380+
else f"--{field.name} {getattr(self, field.name)}"
381+
)
381382
for field in fields(self.__class__)
382383
if getattr(self, field.name) is not None
383384
]

ads/aqua/client/client.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,23 @@
5151
logger = logging.getLogger(__name__)
5252

5353

54-
class OCIAuth(httpx.Auth):
54+
class HttpxOCIAuth(httpx.Auth):
5555
"""
5656
Custom HTTPX authentication class that uses the OCI Signer for request signing.
5757
5858
Attributes:
5959
signer (oci.signer.Signer): The OCI signer used to sign requests.
6060
"""
6161

62-
def __init__(self, signer: oci.signer.Signer):
62+
def __init__(self, signer: Optional[oci.signer.Signer] = None):
6363
"""
64-
Initialize the OCIAuth instance.
64+
Initialize the HttpxOCIAuth instance.
6565
6666
Args:
6767
signer (oci.signer.Signer): The OCI signer to use for signing requests.
6868
"""
69-
self.signer = signer
69+
70+
self.signer = signer or authutil.default_signer().get("signer")
7071

7172
def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
7273
"""
@@ -256,7 +257,7 @@ def __init__(
256257
auth = auth or authutil.default_signer()
257258
if not callable(auth.get("signer")):
258259
raise ValueError("Auth object must have a 'signer' callable attribute.")
259-
self.auth = OCIAuth(auth["signer"])
260+
self.auth = HttpxOCIAuth(auth["signer"])
260261

261262
logger.debug(
262263
f"Initialized {self.__class__.__name__} with endpoint={self.endpoint}, "
@@ -352,7 +353,7 @@ def __init__(self, *args, **kwargs) -> None:
352353
**kwargs: Keyword arguments forwarded to BaseClient.
353354
"""
354355
super().__init__(*args, **kwargs)
355-
self._client = httpx.Client(timeout=self.timeout)
356+
self._client = httpx.Client(timeout=self.timeout, auth=self.auth)
356357

357358
def is_closed(self) -> bool:
358359
return self._client.is_closed
@@ -400,7 +401,6 @@ def _request(
400401
response = self._client.post(
401402
self.endpoint,
402403
headers=self._prepare_headers(stream=False, headers=headers),
403-
auth=self.auth,
404404
json=payload,
405405
)
406406
logger.debug(f"Received response with status code: {response.status_code}")
@@ -447,7 +447,6 @@ def _stream(
447447
"POST",
448448
self.endpoint,
449449
headers=self._prepare_headers(stream=True, headers=headers),
450-
auth=self.auth,
451450
json={**payload, "stream": True},
452451
) as response:
453452
try:
@@ -581,7 +580,7 @@ def __init__(self, *args, **kwargs) -> None:
581580
**kwargs: Keyword arguments forwarded to BaseClient.
582581
"""
583582
super().__init__(*args, **kwargs)
584-
self._client = httpx.AsyncClient(timeout=self.timeout)
583+
self._client = httpx.AsyncClient(timeout=self.timeout, auth=self.auth)
585584

586585
def is_closed(self) -> bool:
587586
return self._client.is_closed
@@ -637,7 +636,6 @@ async def _request(
637636
response = await self._client.post(
638637
self.endpoint,
639638
headers=self._prepare_headers(stream=False, headers=headers),
640-
auth=self.auth,
641639
json=payload,
642640
)
643641
logger.debug(f"Received response with status code: {response.status_code}")
@@ -683,7 +681,6 @@ async def _stream(
683681
"POST",
684682
self.endpoint,
685683
headers=self._prepare_headers(stream=True, headers=headers),
686-
auth=self.auth,
687684
json={**payload, "stream": True},
688685
) as response:
689686
try:
@@ -797,3 +794,43 @@ async def embeddings(
797794
logger.debug(f"Generating embeddings with input: {input}, payload: {payload}")
798795
payload = {**(payload or {}), "input": input}
799796
return await self._request(payload=payload, headers=headers)
797+
798+
799+
def get_httpx_client(**kwargs: Any) -> httpx.Client:
800+
"""
801+
Creates and returns a synchronous httpx Client configured with OCI authentication signer based
802+
the authentication type setup using ads.set_auth method or env variable OCI_IAM_TYPE.
803+
More information - https://accelerated-data-science.readthedocs.io/en/stable/user_guide/cli/authentication.html
804+
805+
Parameters
806+
----------
807+
**kwargs : Any
808+
Keyword arguments supported by httpx.Client
809+
810+
Returns
811+
-------
812+
Client
813+
A configured synchronous httpx Client instance.
814+
"""
815+
kwargs["auth"] = kwargs.get("auth") or HttpxOCIAuth()
816+
return httpx.Client(**kwargs)
817+
818+
819+
def get_async_httpx_client(**kwargs: Any) -> httpx.AsyncClient:
820+
"""
821+
Creates and returns a synchronous httpx Client configured with OCI authentication signer based
822+
the authentication type setup using ads.set_auth method or env variable OCI_IAM_TYPE.
823+
More information - https://accelerated-data-science.readthedocs.io/en/stable/user_guide/cli/authentication.html
824+
825+
Parameters
826+
----------
827+
**kwargs : Any
828+
Keyword arguments supported by httpx.Client
829+
830+
Returns
831+
-------
832+
AsyncClient
833+
A configured asynchronous httpx AsyncClient instance.
834+
"""
835+
kwargs["auth"] = kwargs.get("auth") or HttpxOCIAuth()
836+
return httpx.AsyncClient(**kwargs)

ads/aqua/common/entities.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
from typing import Any, Dict, Optional
6+
7+
from oci.data_science.models import Model
8+
from pydantic import BaseModel, Field
9+
510

611
class ContainerSpec:
712
"""
@@ -15,3 +20,25 @@ class ContainerSpec:
1520
ENV_VARS = "envVars"
1621
RESTRICTED_PARAMS = "restrictedParams"
1722
EVALUATION_CONFIGURATION = "evaluationConfiguration"
23+
24+
25+
class ModelConfigResult(BaseModel):
26+
"""
27+
Represents the result of getting the AQUA model configuration.
28+
29+
Attributes:
30+
model_details (Dict[str, Any]): A dictionary containing model details extracted from OCI.
31+
config (Dict[str, Any]): A dictionary of the loaded configuration.
32+
"""
33+
34+
config: Optional[Dict[str, Any]] = Field(
35+
None, description="Loaded configuration dictionary."
36+
)
37+
model_details: Optional[Model] = Field(
38+
None, description="Details of the model from OCI."
39+
)
40+
41+
class Config:
42+
extra = "ignore"
43+
arbitrary_types_allowed = True
44+
protected_namespaces = ()

0 commit comments

Comments
 (0)