Skip to content

Commit 7c019ca

Browse files
Merge branch 'main' into feature/aqua-v1.0.5
2 parents daef898 + 60d48f2 commit 7c019ca

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1302
-544
lines changed

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ jobs:
5656
$CONDA/bin/conda init
5757
source /home/runner/.bashrc
5858
pip install -r test-requirements-operators.txt
59-
pip install "oracle-automlx[classic]>=24.2.0"
60-
pip install "oracle-automlx[forecasting]>=24.2.0"
59+
pip install "oracle-automlx[forecasting]>=24.4.0"
6160
pip install pandas>=2.2.0
6261
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast

ads/aqua/common/enums.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5252
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
5353
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
5454
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
55+
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
5556

5657

5758
class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
@@ -80,3 +81,11 @@ class RqsAdditionalDetails(str, metaclass=ExtendedEnumMeta):
8081
MODEL_VERSION_SET_NAME = "modelVersionSetName"
8182
PROJECT_ID = "projectId"
8283
VERSION_LABEL = "versionLabel"
84+
85+
86+
class TextEmbeddingInferenceContainerParams(str, metaclass=ExtendedEnumMeta):
    """Parameter names used for the Text Embedding Inference (TEI) container.

    Only the subset of params required for enabling model deployment in
    OCI Data Science is listed here. The full set of options is documented at
    https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments
    """

    # Name of the model-id argument, stored without any leading dashes.
    MODEL_ID = "model-id"
    # Name of the port argument, stored without any leading dashes.
    PORT = "port"

ads/aqua/common/utils.py

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
InferenceContainerParamType,
3636
InferenceContainerType,
3737
RqsAdditionalDetails,
38+
TextEmbeddingInferenceContainerParams,
3839
)
3940
from ads.aqua.common.errors import (
4041
AquaFileNotFoundError,
@@ -51,6 +52,7 @@
5152
MODEL_BY_REFERENCE_OSS_PATH_KEY,
5253
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
5354
SUPPORTED_FILE_FORMATS,
55+
TEI_CONTAINER_DEFAULT_HOST,
5456
TGI_INFERENCE_RESTRICTED_PARAMS,
5557
UNKNOWN,
5658
UNKNOWN_JSON_STR,
@@ -63,7 +65,12 @@
6365
from ads.common.object_storage_details import ObjectStorageDetails
6466
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
6567
from ads.common.utils import copy_file, get_console_link, upload_to_os
66-
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
68+
from ads.config import (
69+
AQUA_MODEL_DEPLOYMENT_FOLDER,
70+
AQUA_SERVICE_MODELS_BUCKET,
71+
CONDA_BUCKET_NS,
72+
TENANCY_OCID,
73+
)
6774
from ads.model import DataScienceModel, ModelVersionSet
6875

6976
logger = logging.getLogger("ads.aqua")
@@ -569,15 +576,13 @@ def get_container_image(
569576
A dict of allowed configs.
570577
"""
571578

579+
container_image = UNKNOWN
572580
config = config_file_name or get_container_config()
573581
config_file_name = service_config_path()
574582

575583
if container_type not in config:
576-
raise AquaValueError(
577-
f"{config_file_name} does not have config details for model: {container_type}"
578-
)
584+
return UNKNOWN
579585

580-
container_image = None
581586
mapping = config[container_type]
582587
versions = [obj["version"] for obj in mapping]
583588
# assumes numbered versions, update if `latest` is used
@@ -1078,3 +1083,76 @@ def list_hf_models(query: str) -> List[str]:
10781083
return [model.id for model in models if model.disabled is None]
10791084
except HfHubHTTPError as err:
10801085
raise format_hf_custom_error_message(err) from err
1086+
1087+
1088+
def generate_tei_cmd_var(os_path: str) -> List[str]:
    """Build the CMD arguments for a Text Embedding Inference container.

    Only the parameters essential for OCI model deployment are emitted;
    everything else is left at the container defaults.

    Parameters
    ----------
    os_path: str
        OCI bucket path where the model artifacts are uploaded, in the form
        oci://bucket@namespace/prefix. Trailing slashes are stripped before
        the path is parsed.

    Returns
    -------
    List[str]
        Command line arguments as a flat list of flag/value pairs:
        the model-id flag, the model location, the port flag, the port value.
    """
    flag = "--"

    # The model location is the deployment folder joined with the object
    # storage file path of the artifacts, always terminated with a slash.
    artifact_path = ObjectStorageDetails.from_path(os_path.rstrip("/")).filepath
    model_location = f"{AQUA_MODEL_DEPLOYMENT_FOLDER}{artifact_path}/"

    return [
        f"{flag}{TextEmbeddingInferenceContainerParams.MODEL_ID}",
        model_location,
        f"{flag}{TextEmbeddingInferenceContainerParams.PORT}",
        # NOTE(review): the constant is named *_HOST but is used here as the
        # value that follows the port flag.
        TEI_CONTAINER_DEFAULT_HOST,
    ]
1111+
1112+
1113+
def parse_cmd_var(cmd_list: List[str]) -> dict:
    """Parse a flat list of CMD tokens into a key-value dictionary.

    A token starting with ``--`` is treated as a key. Its value is the token
    immediately after it, unless that token is itself a key or the list ends,
    in which case the value is ``None``. Tokens that neither start with ``--``
    nor directly follow a key are ignored. When a key occurs more than once,
    the last occurrence wins.

    Parameters
    ----------
    cmd_list: List[str]
        Command line tokens, e.g. ``["--port", "8080", "--verbose"]``.

    Returns
    -------
    dict
        Mapping of each key (including its ``--`` prefix) to its value or None.
    """
    parsed_cmd = {}
    total = len(cmd_list)

    for position, token in enumerate(cmd_list):
        if not token.startswith("--"):
            # Value tokens are consumed via look-ahead from their key, so no
            # manual index skipping is needed (reassigning the loop variable
            # of a ``for`` loop would have no effect anyway).
            continue

        next_position = position + 1
        has_value = next_position < total and not cmd_list[
            next_position
        ].startswith("--")
        parsed_cmd[token] = cmd_list[next_position] if has_value else None

    return parsed_cmd
1127+
1128+
1129+
def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
    """Combine default CMD parameters with overrides, rejecting key clashes.

    Parameters
    ----------
    cmd_var: List[str]
        Default list of parameters. Every element is converted with ``str``.
    overrides: List[str]
        List of additional parameters. May be empty or None, in which case
        only the stringified defaults are returned.

    Returns
    -------
    List[str]
        The stringified defaults followed by the stringified overrides.

    Raises
    ------
    AquaValueError
        If any ``--`` prefixed key appears in both lists.
    """
    cmd_var = [str(x) for x in cmd_var]
    if not overrides:
        return cmd_var
    overrides = [str(x) for x in overrides]

    # check for conflicts
    common_keys = parse_cmd_var(cmd_var).keys() & parse_cmd_var(overrides).keys()
    if common_keys:
        # Sort the keys so the message is stable across runs; iteration order
        # of a set of strings is not deterministic between processes.
        raise AquaValueError(
            f"The following CMD input cannot be overridden for model deployment: {', '.join(sorted(common_keys))}"
        )

    return cmd_var + overrides

ads/aqua/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,4 @@
8080
"--port",
8181
"--host",
8282
}
83+
# NOTE(review): the name says HOST but the value looks like a port number;
# confirm the intended naming against the call sites.
TEI_CONTAINER_DEFAULT_HOST = "8080"

ads/aqua/evaluation/entities.py

Lines changed: 45 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,18 @@
99
This module contains dataclasses for aqua evaluation.
1010
"""
1111

12-
from dataclasses import dataclass, field
13-
from typing import List, Optional, Union
12+
from pydantic import Field
13+
from typing import Any, Dict, List, Optional, Union
1414

1515
from ads.aqua.data import AquaResourceIdentifier
16-
from ads.common.serializer import DataClassSerializable
16+
from ads.aqua.config.utils.serializer import Serializable
1717

1818

19-
@dataclass(repr=False)
20-
class CreateAquaEvaluationDetails(DataClassSerializable):
21-
"""Dataclass to create aqua model evaluation.
19+
class CreateAquaEvaluationDetails(Serializable):
20+
"""Class for creating aqua model evaluation.
2221
23-
Fields
24-
------
22+
Properties
23+
----------
2524
evaluation_source_id: str
2625
The evaluation source id. Must be either model or model deployment ocid.
2726
evaluation_name: str
@@ -83,69 +82,64 @@ class CreateAquaEvaluationDetails(DataClassSerializable):
8382
ocpus: Optional[float] = None
8483
log_group_id: Optional[str] = None
8584
log_id: Optional[str] = None
86-
metrics: Optional[List] = None
85+
metrics: Optional[List[str]] = None
8786
force_overwrite: Optional[bool] = False
8887

88+
class Config:
89+
extra = "ignore"
8990

90-
@dataclass(repr=False)
91-
class AquaEvalReport(DataClassSerializable):
91+
class AquaEvalReport(Serializable):
9292
evaluation_id: str = ""
9393
content: str = ""
9494

95+
class Config:
96+
extra = "ignore"
9597

96-
@dataclass(repr=False)
97-
class ModelParams(DataClassSerializable):
98-
max_tokens: str = ""
99-
top_p: str = ""
100-
top_k: str = ""
101-
temperature: str = ""
102-
presence_penalty: Optional[float] = 0.0
103-
frequency_penalty: Optional[float] = 0.0
104-
stop: Optional[Union[str, List[str]]] = field(default_factory=list)
105-
model: Optional[str] = "odsc-llm"
106-
107-
108-
@dataclass(repr=False)
109-
class AquaEvalParams(ModelParams, DataClassSerializable):
98+
class AquaEvalParams(Serializable):
11099
shape: str = ""
111100
dataset_path: str = ""
112101
report_path: str = ""
113102

103+
class Config:
104+
extra = "allow"
114105

115-
@dataclass(repr=False)
116-
class AquaEvalMetric(DataClassSerializable):
106+
class AquaEvalMetric(Serializable):
117107
key: str
118108
name: str
119109
description: str = ""
120110

111+
class Config:
112+
extra = "ignore"
121113

122-
@dataclass(repr=False)
123-
class AquaEvalMetricSummary(DataClassSerializable):
114+
class AquaEvalMetricSummary(Serializable):
124115
metric: str = ""
125116
score: str = ""
126117
grade: str = ""
127118

119+
class Config:
120+
extra = "ignore"
128121

129-
@dataclass(repr=False)
130-
class AquaEvalMetrics(DataClassSerializable):
122+
class AquaEvalMetrics(Serializable):
131123
id: str
132124
report: str
133-
metric_results: List[AquaEvalMetric] = field(default_factory=list)
134-
metric_summary_result: List[AquaEvalMetricSummary] = field(default_factory=list)
125+
metric_results: List[AquaEvalMetric] = Field(default_factory=list)
126+
metric_summary_result: List[AquaEvalMetricSummary] = Field(default_factory=list)
135127

128+
class Config:
129+
extra = "ignore"
136130

137-
@dataclass(repr=False)
138-
class AquaEvaluationCommands(DataClassSerializable):
131+
class AquaEvaluationCommands(Serializable):
139132
evaluation_id: str
140133
evaluation_target_id: str
141-
input_data: dict
142-
metrics: list
134+
input_data: Dict[str, Any]
135+
metrics: List[str]
143136
output_dir: str
144-
params: dict
137+
params: Dict[str, Any]
145138

139+
class Config:
140+
extra = "ignore"
146141

147-
@dataclass(repr=False)
148-
class AquaEvaluationSummary(DataClassSerializable):
142+
class AquaEvaluationSummary(Serializable):
149143
"""Represents a summary of Aqua evalution."""
150144

151145
id: str
@@ -154,17 +148,18 @@ class AquaEvaluationSummary(DataClassSerializable):
154148
lifecycle_state: str
155149
lifecycle_details: str
156150
time_created: str
157-
tags: dict
158-
experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
159-
source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
160-
job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
161-
parameters: AquaEvalParams = field(default_factory=AquaEvalParams)
151+
tags: Dict[str, Any]
152+
experiment: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
153+
source: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
154+
job: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
155+
parameters: AquaEvalParams = Field(default_factory=AquaEvalParams)
162156

157+
class Config:
158+
extra = "ignore"
163159

164-
@dataclass(repr=False)
165-
class AquaEvaluationDetail(AquaEvaluationSummary, DataClassSerializable):
160+
class AquaEvaluationDetail(AquaEvaluationSummary):
166161
"""Represents a details of Aqua evalution."""
167162

168-
log_group: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
169-
log: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
170-
introspection: dict = field(default_factory=dict)
163+
log_group: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
164+
log: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
165+
introspection: dict = Field(default_factory=dict)

0 commit comments

Comments
 (0)