Skip to content

Commit ac49a3e

Browse files
Resolving conflicts
2 parents df5037e + 3d5b790 commit ac49a3e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+104809
-444
lines changed

.github/workflows/run-operators-unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
strategy:
2828
fail-fast: false
2929
matrix:
30-
python-version: ["3.8", "3.9", "3.10"]
30+
python-version: ["3.8", "3.9", "3.10", "3.11"]
3131

3232
steps:
3333
- uses: actions/checkout@v4

.github/workflows/run-unittests-default_setup.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: "[Py3.8][Py3.9][Py3.10] tests/unitary/default_setup/**"
1+
name: "[Py3.8-3.11] - Default Tests"
22

33
on:
44
workflow_dispatch:
@@ -33,7 +33,7 @@ jobs:
3333
strategy:
3434
fail-fast: false
3535
matrix:
36-
python-version: ["3.8", "3.9", "3.10"]
36+
python-version: ["3.8", "3.9", "3.10", "3.11"]
3737

3838
steps:
3939
- uses: actions/checkout@v4

.github/workflows/run-unittests-py38-cov-report.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: "[Py3.8][COV REPORT] tests/unitary/**"
1+
name: "[Py3.8][COV REPORT] - All Unit Tests"
22

33
on:
44
workflow_dispatch:
@@ -44,6 +44,7 @@ jobs:
4444
ignore-path: |
4545
--ignore tests/unitary/with_extras/model \
4646
--ignore tests/unitary/with_extras/feature_store \
47+
--ignore tests/unitary/with_extras/operator/feature-store \
4748
--ignore tests/unitary/with_extras/operator/forecast \
4849
--ignore tests/unitary/with_extras/hpo
4950
- name: "slow_tests"

.github/workflows/run-unittests-py39-py310.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: "[Py3.9][Py3.10] - tests/unitary/**"
1+
name: "[Py3.9-3.11] - All Unit Tests"
22

33
on:
44
workflow_dispatch:
@@ -33,7 +33,7 @@ jobs:
3333
strategy:
3434
fail-fast: false
3535
matrix:
36-
python-version: ["3.9", "3.10"]
36+
python-version: ["3.9", "3.10", "3.11"]
3737
name: ["unitary", "slow_tests"]
3838
include:
3939
- name: "unitary"
@@ -46,6 +46,7 @@ jobs:
4646
ignore-path: |
4747
--ignore tests/unitary/with_extras/model \
4848
--ignore tests/unitary/with_extras/feature_store \
49+
--ignore tests/unitary/with_extras/operator/feature-store \
4950
--ignore tests/unitary/with_extras/operator/forecast \
5051
--ignore tests/unitary/with_extras/operator/pii \
5152
--ignore tests/unitary/with_extras/hpo

.github/workflows/test-env-setup/action.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ runs:
77
using: composite
88
steps:
99
- shell: bash
10+
env:
11+
# torch > v2.1.0 brings nvidia dependency by default, we want to install torch for cpu for tests.
12+
# Fot that --index-url https://download.pytorch.org/whl/cpu needed - https://pytorch.org/get-started/locally/.
13+
# Setting env variable here instead of flag --index-url (docs: https://pip.pypa.io/en/stable/cli/pip_install/):
14+
PIP_EXTRA_INDEX_URL: "https://download.pytorch.org/whl/cpu"
1015
run: |
1116
set -x # print commands that are executed
1217

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
repos:
22
# ruff
33
- repo: https://github.com/astral-sh/ruff-pre-commit
4-
rev: v0.4.9
4+
rev: v0.5.0
55
hooks:
66
- id: ruff
77
types_or: [ python, pyi, jupyter ]
@@ -13,7 +13,7 @@ repos:
1313
exclude: ^docs/
1414
# Standard hooks
1515
- repo: https://github.com/pre-commit/pre-commit-hooks
16-
rev: v4.4.0
16+
rev: v4.6.0
1717
hooks:
1818
- id: check-ast
1919
exclude: ^docs/
@@ -42,7 +42,7 @@ repos:
4242
files: ^docs/
4343
# Hardcoded secrets and ocids detector
4444
- repo: https://github.com/gitleaks/gitleaks
45-
rev: v8.17.0
45+
rev: v8.18.4
4646
hooks:
4747
- id: gitleaks
4848
exclude: .github/workflows/reusable-actions/set-dummy-conf.yml

ads/aqua/common/utils.py

Lines changed: 77 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
"""AQUA utils and constants."""
5+
66
import asyncio
77
import base64
88
import json
@@ -19,13 +19,30 @@
1919
import oci
2020
from oci.data_science.models import JobRun, Model
2121

22-
from ads.aqua.common.enums import RqsAdditionalDetails
22+
from ads.aqua.common.enums import (
23+
InferenceContainerParamType,
24+
InferenceContainerType,
25+
RqsAdditionalDetails,
26+
)
2327
from ads.aqua.common.errors import (
2428
AquaFileNotFoundError,
2529
AquaRuntimeError,
2630
AquaValueError,
2731
)
28-
from ads.aqua.constants import *
32+
from ads.aqua.constants import (
33+
AQUA_GA_LIST,
34+
COMPARTMENT_MAPPING_KEY,
35+
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
36+
CONTAINER_INDEX,
37+
MAXIMUM_ALLOWED_DATASET_IN_BYTE,
38+
MODEL_BY_REFERENCE_OSS_PATH_KEY,
39+
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
40+
SUPPORTED_FILE_FORMATS,
41+
TGI_INFERENCE_RESTRICTED_PARAMS,
42+
UNKNOWN,
43+
UNKNOWN_JSON_STR,
44+
VLLM_INFERENCE_RESTRICTED_PARAMS,
45+
)
2946
from ads.aqua.data import AquaResourceIdentifier
3047
from ads.common.auth import default_signer
3148
from ads.common.decorator.threaded import threaded
@@ -74,15 +91,15 @@ def get_status(evaluation_status: str, job_run_status: str = None):
7491

7592
status = LifecycleStatus.UNKNOWN
7693
if evaluation_status == Model.LIFECYCLE_STATE_ACTIVE:
77-
if (
78-
job_run_status == JobRun.LIFECYCLE_STATE_IN_PROGRESS
79-
or job_run_status == JobRun.LIFECYCLE_STATE_ACCEPTED
80-
):
94+
if job_run_status in {
95+
JobRun.LIFECYCLE_STATE_IN_PROGRESS,
96+
JobRun.LIFECYCLE_STATE_ACCEPTED,
97+
}:
8198
status = JobRun.LIFECYCLE_STATE_IN_PROGRESS
82-
elif (
83-
job_run_status == JobRun.LIFECYCLE_STATE_FAILED
84-
or job_run_status == JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION
85-
):
99+
elif job_run_status in {
100+
JobRun.LIFECYCLE_STATE_FAILED,
101+
JobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
102+
}:
86103
status = JobRun.LIFECYCLE_STATE_FAILED
87104
else:
88105
status = job_run_status
@@ -199,10 +216,7 @@ def read_file(file_path: str, **kwargs) -> str:
199216
@threaded()
200217
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
201218
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
202-
if artifact_path.startswith("oci://"):
203-
signer = default_signer()
204-
else:
205-
signer = {}
219+
signer = default_signer() if artifact_path.startswith("oci://") else {}
206220
config = json.loads(
207221
read_file(file_path=artifact_path, auth=signer, **kwargs) or UNKNOWN_JSON_STR
208222
)
@@ -448,7 +462,7 @@ def _build_resource_identifier(
448462

449463

450464
def _get_experiment_info(
451-
model: Union[oci.resource_search.models.ResourceSummary, DataScienceModel]
465+
model: Union[oci.resource_search.models.ResourceSummary, DataScienceModel],
452466
) -> tuple:
453467
"""Returns ocid and name of the experiment."""
454468
return (
@@ -609,7 +623,7 @@ def extract_id_and_name_from_tag(tag: str):
609623
base_model_name = UNKNOWN
610624
try:
611625
base_model_ocid, base_model_name = tag.split("#")
612-
except:
626+
except Exception:
613627
pass
614628

615629
if not (is_valid_ocid(base_model_ocid) and base_model_name):
@@ -646,7 +660,7 @@ def get_resource_name(ocid: str) -> str:
646660
try:
647661
resource = query_resource(ocid, return_all=False)
648662
name = resource.display_name if resource else UNKNOWN
649-
except:
663+
except Exception:
650664
name = UNKNOWN
651665
return name
652666

@@ -670,8 +684,8 @@ def get_model_by_reference_paths(model_file_description: dict):
670684

671685
if not models:
672686
raise AquaValueError(
673-
f"Model path is not available in the model json artifact. "
674-
f"Please check if the model created by reference has the correct artifact."
687+
"Model path is not available in the model json artifact. "
688+
"Please check if the model created by reference has the correct artifact."
675689
)
676690

677691
if len(models) > 0:
@@ -848,3 +862,46 @@ def copy_model_config(artifact_path: str, os_path: str, auth: dict = None):
848862
except Exception as ex:
849863
logger.debug(ex)
850864
logger.debug(f"Failed to copy config folder from {artifact_path} to {os_path}.")
865+
866+
867+
def get_container_params_type(container_type_name: str) -> str:
868+
"""The utility function accepts the deployment container type name and returns the corresponding params name.
869+
Parameters
870+
----------
871+
container_type_name: str
872+
type of deployment container, like odsc-vllm-serving or odsc-tgi-serving.
873+
874+
Returns
875+
-------
876+
InferenceContainerParamType value
877+
878+
"""
879+
# check substring instead of direct match in case container_type_name changes in the future
880+
if InferenceContainerType.CONTAINER_TYPE_VLLM in container_type_name.lower():
881+
return InferenceContainerParamType.PARAM_TYPE_VLLM
882+
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
883+
return InferenceContainerParamType.PARAM_TYPE_TGI
884+
else:
885+
return UNKNOWN
886+
887+
888+
def get_restricted_params_by_container(container_type_name: str) -> set:
889+
"""The utility function accepts the deployment container type name and returns a set of restricted params
890+
for that container.
891+
Parameters
892+
----------
893+
container_type_name: str
894+
type of deployment container, like odsc-vllm-serving or odsc-tgi-serving.
895+
896+
Returns
897+
-------
898+
Set of restricted params based on container type
899+
900+
"""
901+
# check substring instead of direct match in case container_type_name changes in the future
902+
if InferenceContainerType.CONTAINER_TYPE_VLLM in container_type_name.lower():
903+
return VLLM_INFERENCE_RESTRICTED_PARAMS
904+
elif InferenceContainerType.CONTAINER_TYPE_TGI in container_type_name.lower():
905+
return TGI_INFERENCE_RESTRICTED_PARAMS
906+
else:
907+
return set()

ads/aqua/constants.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
"""This module defines constants used in ads.aqua module."""
@@ -45,19 +44,33 @@
4544
SUPPORTED_FILE_FORMATS = ["jsonl"]
4645
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
4746

48-
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = dict(
49-
datasciencemodel="models",
50-
datasciencemodeldeployment="model-deployments",
51-
datasciencemodeldeploymentdev="model-deployments",
52-
datasciencemodeldeploymentint="model-deployments",
53-
datasciencemodeldeploymentpre="model-deployments",
54-
datasciencejob="jobs",
55-
datasciencejobrun="job-runs",
56-
datasciencejobrundev="job-runs",
57-
datasciencejobrunint="job-runs",
58-
datasciencejobrunpre="job-runs",
59-
datasciencemodelversionset="model-version-sets",
60-
datasciencemodelversionsetpre="model-version-sets",
61-
datasciencemodelversionsetint="model-version-sets",
62-
datasciencemodelversionsetdev="model-version-sets",
63-
)
47+
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
48+
"datasciencemodel": "models",
49+
"datasciencemodeldeployment": "model-deployments",
50+
"datasciencemodeldeploymentdev": "model-deployments",
51+
"datasciencemodeldeploymentint": "model-deployments",
52+
"datasciencemodeldeploymentpre": "model-deployments",
53+
"datasciencejob": "jobs",
54+
"datasciencejobrun": "job-runs",
55+
"datasciencejobrundev": "job-runs",
56+
"datasciencejobrunint": "job-runs",
57+
"datasciencejobrunpre": "job-runs",
58+
"datasciencemodelversionset": "model-version-sets",
59+
"datasciencemodelversionsetpre": "model-version-sets",
60+
"datasciencemodelversionsetint": "model-version-sets",
61+
"datasciencemodelversionsetdev": "model-version-sets",
62+
}
63+
64+
VLLM_INFERENCE_RESTRICTED_PARAMS = {
65+
"--port",
66+
"--host",
67+
"--served-model-name",
68+
"--seed",
69+
}
70+
TGI_INFERENCE_RESTRICTED_PARAMS = {
71+
"--port",
72+
"--hostname",
73+
"--num-shard",
74+
"--sharded",
75+
"--trust-remote-code",
76+
}

0 commit comments

Comments
 (0)