Skip to content

Commit 063f5fb

Browse files
qiuosierliudmylaru
andauthored
Merge feature/aqua into main with all commits squashed. (#717)
Co-authored-by: lrudenka <liuda.rudenka@oracle.com>
1 parent 8f59ae1 commit 063f5fb

File tree

83 files changed

+10003
-341
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

83 files changed

+10003
-341
lines changed

.github/workflows/run-unittests-py38-cov-report.yml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ jobs:
3434
test-path: "tests/unitary"
3535
# `model` tests running in "slow_tests",
3636
# `feature_store` tests has its own test suite
37+
# `forecast` tests has its own test suite
3738
# 'hpo' tests hangs if run together with all unitary tests. Tests running in separate command before running all unitary
3839
ignore-path: |
3940
--ignore tests/unitary/with_extras/model \
4041
--ignore tests/unitary/with_extras/feature_store \
42+
--ignore tests/unitary/with_extras/operator/forecast \
4143
--ignore tests/unitary/with_extras/hpo
4244
- name: "slow_tests"
4345
test-path: "tests/unitary/with_extras/model"
@@ -63,11 +65,6 @@ jobs:
6365
name: "Test env setup"
6466
timeout-minutes: 30
6567

66-
# Installing forecast deps for python3.8 test setup only, it will not work with python3.9/3.10, because
67-
# automlx do not support py3.9 and some versions of py3.10. This step omitted in -py39-py30.yml workflow
68-
- name: "Install Forecasting dependencies"
69-
run: |
70-
pip install -e ".[forecast]"
7168
# Installing pii deps for python3.8 test setup only, it will not work with python3.9/3.10, because
7269
# 'datapane' library conflicts with pandas>2.2.0, which used in py3.9/3.10 setup
7370
- name: "Install PII dependencies"

THIRD_PARTY_LICENSES.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,12 @@ delta
405405
* Source code: https://github.com/delta-io/delta/
406406
* Project home: https://delta.io/
407407

408+
cachetools
409+
* Copyright (c) 2014-2024 Thomas Kemmer
410+
* License: The MIT License (MIT)
411+
* Source code: https://github.com/tkem/cachetools/
412+
* Project home: https://cachetools.readthedocs.io/
413+
408414
=============================== Licenses ===============================
409415
------------------------------------------------------------------------
410416

ads/aqua/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
7+
import logging
8+
import sys
9+
10+
logger = logging.getLogger(__name__)
11+
handler = logging.StreamHandler(sys.stdout)
12+
logger.setLevel(logging.INFO)

ads/aqua/base.py

Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
import os
7+
from typing import Dict, Union
8+
9+
import oci
10+
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
11+
12+
from ads import set_auth
13+
from ads.aqua import logger
14+
from ads.aqua.data import Tags
15+
from ads.aqua.exception import AquaRuntimeError, AquaValueError
16+
from ads.aqua.utils import (
17+
UNKNOWN,
18+
_is_valid_mvs,
19+
get_artifact_path,
20+
get_base_model_from_tags,
21+
is_valid_ocid,
22+
load_config,
23+
logger,
24+
)
25+
from ads.common import oci_client as oc
26+
from ads.common.auth import default_signer
27+
from ads.common.utils import extract_region
28+
from ads.config import (
29+
AQUA_TELEMETRY_BUCKET,
30+
AQUA_TELEMETRY_BUCKET_NS,
31+
OCI_ODSC_SERVICE_ENDPOINT,
32+
OCI_RESOURCE_PRINCIPAL_VERSION,
33+
)
34+
from ads.model.datascience_model import DataScienceModel
35+
from ads.model.deployment.model_deployment import ModelDeployment
36+
from ads.model.model_metadata import (
37+
ModelCustomMetadata,
38+
ModelProvenanceMetadata,
39+
ModelTaxonomyMetadata,
40+
)
41+
from ads.model.model_version_set import ModelVersionSet
42+
from ads.telemetry import telemetry
43+
from ads.telemetry.client import TelemetryClient
44+
45+
46+
class AquaApp:
47+
"""Base Aqua App to contain common components."""
48+
49+
@telemetry(name="aqua")
50+
def __init__(self) -> None:
51+
if OCI_RESOURCE_PRINCIPAL_VERSION:
52+
set_auth("resource_principal")
53+
self._auth = default_signer({"service_endpoint": OCI_ODSC_SERVICE_ENDPOINT})
54+
self.ds_client = oc.OCIClientFactory(**self._auth).data_science
55+
self.logging_client = oc.OCIClientFactory(**default_signer()).logging_management
56+
self.identity_client = oc.OCIClientFactory(**default_signer()).identity
57+
self.region = extract_region(self._auth)
58+
self._telemetry = None
59+
60+
def list_resource(
61+
self,
62+
list_func_ref,
63+
**kwargs,
64+
) -> list:
65+
"""Generic method to list OCI Data Science resources.
66+
67+
Parameters
68+
----------
69+
list_func_ref : function
70+
A reference to the list operation which will be called.
71+
**kwargs :
72+
Additional keyword arguments to filter the resource.
73+
The kwargs are passed into OCI API.
74+
75+
Returns
76+
-------
77+
list
78+
A list of OCI Data Science resources.
79+
"""
80+
return oci.pagination.list_call_get_all_results(
81+
list_func_ref,
82+
**kwargs,
83+
).data
84+
85+
def update_model(self, model_id: str, update_model_details: UpdateModelDetails):
86+
"""Updates model details.
87+
88+
Parameters
89+
----------
90+
model_id : str
91+
The id of target model.
92+
update_model_details: UpdateModelDetails
93+
The model details to be updated.
94+
"""
95+
self.ds_client.update_model(
96+
model_id=model_id, update_model_details=update_model_details
97+
)
98+
99+
def update_model_provenance(
100+
self,
101+
model_id: str,
102+
update_model_provenance_details: UpdateModelProvenanceDetails,
103+
):
104+
"""Updates model provenance details.
105+
106+
Parameters
107+
----------
108+
model_id : str
109+
The id of target model.
110+
update_model_provenance_details: UpdateModelProvenanceDetails
111+
The model provenance details to be updated.
112+
"""
113+
self.ds_client.update_model_provenance(
114+
model_id=model_id,
115+
update_model_provenance_details=update_model_provenance_details,
116+
)
117+
118+
# TODO: refactor model evaluation implementation to use it.
119+
@staticmethod
120+
def get_source(source_id: str) -> Union[ModelDeployment, DataScienceModel]:
121+
if is_valid_ocid(source_id):
122+
if "datasciencemodeldeployment" in source_id:
123+
return ModelDeployment.from_id(source_id)
124+
elif "datasciencemodel" in source_id:
125+
return DataScienceModel.from_id(source_id)
126+
127+
raise AquaValueError(
128+
f"Invalid source {source_id}. "
129+
"Specify either a model or model deployment id."
130+
)
131+
132+
# TODO: refactor model evaluation implementation to use it.
133+
@staticmethod
134+
def create_model_version_set(
135+
model_version_set_id: str = None,
136+
model_version_set_name: str = None,
137+
description: str = None,
138+
compartment_id: str = None,
139+
project_id: str = None,
140+
**kwargs,
141+
) -> tuple:
142+
"""Creates ModelVersionSet from given ID or Name.
143+
144+
Parameters
145+
----------
146+
model_version_set_id: (str, optional):
147+
ModelVersionSet OCID.
148+
model_version_set_name: (str, optional):
149+
ModelVersionSet Name.
150+
description: (str, optional):
151+
TBD
152+
compartment_id: (str, optional):
153+
Compartment OCID.
154+
project_id: (str, optional):
155+
Project OCID.
156+
tag: (str, optional)
157+
calling tag, can be Tags.AQUA_FINE_TUNING or Tags.AQUA_EVALUATION
158+
159+
Returns
160+
-------
161+
tuple: (model_version_set_id, model_version_set_name)
162+
"""
163+
# TODO: tag should be selected based on which operation (eval/FT) invoke this method
164+
# currently only used by fine-tuning flow.
165+
tag = Tags.AQUA_FINE_TUNING.value
166+
167+
if not model_version_set_id:
168+
tag = Tags.AQUA_FINE_TUNING.value # TODO: Fix this
169+
try:
170+
model_version_set = ModelVersionSet.from_name(
171+
name=model_version_set_name,
172+
compartment_id=compartment_id,
173+
)
174+
175+
if not _is_valid_mvs(model_version_set, tag):
176+
raise AquaValueError(
177+
f"Invalid model version set name. Please provide a model version set with `{tag}` in tags."
178+
)
179+
180+
except:
181+
logger.debug(
182+
f"Model version set {model_version_set_name} doesn't exist. "
183+
"Creating new model version set."
184+
)
185+
mvs_freeform_tags = {
186+
tag: tag,
187+
}
188+
model_version_set = (
189+
ModelVersionSet()
190+
.with_compartment_id(compartment_id)
191+
.with_project_id(project_id)
192+
.with_name(model_version_set_name)
193+
.with_description(description)
194+
.with_freeform_tags(**mvs_freeform_tags)
195+
# TODO: decide what parameters will be needed
196+
# when refactor eval to use this method, we need to pass tag here.
197+
.create(**kwargs)
198+
)
199+
logger.debug(
200+
f"Successfully created model version set {model_version_set_name} with id {model_version_set.id}."
201+
)
202+
return (model_version_set.id, model_version_set_name)
203+
else:
204+
model_version_set = ModelVersionSet.from_id(model_version_set_id)
205+
# TODO: tag should be selected based on which operation (eval/FT) invoke this method
206+
if not _is_valid_mvs(model_version_set, tag):
207+
raise AquaValueError(
208+
f"Invalid model version set id. Please provide a model version set with `{tag}` in tags."
209+
)
210+
return (model_version_set_id, model_version_set.name)
211+
212+
# TODO: refactor model evaluation implementation to use it.
213+
@staticmethod
214+
def create_model_catalog(
215+
display_name: str,
216+
description: str,
217+
model_version_set_id: str,
218+
model_custom_metadata: Union[ModelCustomMetadata, Dict],
219+
model_taxonomy_metadata: Union[ModelTaxonomyMetadata, Dict],
220+
compartment_id: str,
221+
project_id: str,
222+
**kwargs,
223+
) -> DataScienceModel:
224+
model = (
225+
DataScienceModel()
226+
.with_compartment_id(compartment_id)
227+
.with_project_id(project_id)
228+
.with_display_name(display_name)
229+
.with_description(description)
230+
.with_model_version_set_id(model_version_set_id)
231+
.with_custom_metadata_list(model_custom_metadata)
232+
.with_defined_metadata_list(model_taxonomy_metadata)
233+
.with_provenance_metadata(ModelProvenanceMetadata(training_id=UNKNOWN))
234+
# TODO: decide what parameters will be needed
235+
.create(
236+
**kwargs,
237+
)
238+
)
239+
return model
240+
241+
def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
242+
"""Checks if the artifact exists.
243+
244+
Parameters
245+
----------
246+
model_id : str
247+
The model OCID.
248+
**kwargs :
249+
Additional keyword arguments passed in head_model_artifact.
250+
251+
Returns
252+
-------
253+
bool
254+
Whether the artifact exists.
255+
"""
256+
257+
try:
258+
response = self.ds_client.head_model_artifact(model_id=model_id, **kwargs)
259+
return True if response.status == 200 else False
260+
except oci.exceptions.ServiceError as ex:
261+
if ex.status == 404:
262+
logger.info(f"Artifact not found in model {model_id}.")
263+
return False
264+
265+
def get_config(self, model_id: str, config_file_name: str) -> Dict:
266+
"""Gets the config for the given Aqua model.
267+
268+
Parameters
269+
----------
270+
model_id: str
271+
The OCID of the Aqua model.
272+
config_file_name: str
273+
name of the config file
274+
275+
Returns
276+
-------
277+
Dict:
278+
A dict of allowed configs.
279+
"""
280+
oci_model = self.ds_client.get_model(model_id).data
281+
oci_aqua = (
282+
(
283+
Tags.AQUA_TAG.value in oci_model.freeform_tags
284+
or Tags.AQUA_TAG.value.lower() in oci_model.freeform_tags
285+
)
286+
if oci_model.freeform_tags
287+
else False
288+
)
289+
290+
if not oci_aqua:
291+
raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
292+
293+
config = {}
294+
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
295+
if not artifact_path:
296+
logger.error(
297+
f"Failed to get artifact path from custom metadata for the model: {model_id}"
298+
)
299+
return config
300+
301+
try:
302+
config_path = f"{os.path.dirname(artifact_path)}/config/"
303+
config = load_config(
304+
config_path,
305+
config_file_name=config_file_name,
306+
)
307+
except:
308+
pass
309+
310+
if not config:
311+
logger.error(
312+
f"{config_file_name} is not available for the model: {model_id}. Check if the custom metadata has the artifact path set."
313+
)
314+
return config
315+
316+
return config
317+
318+
@property
319+
def telemetry(self):
320+
if not self._telemetry:
321+
self._telemetry = TelemetryClient(
322+
bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS
323+
)
324+
return self._telemetry

ads/aqua/cli.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2024 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
from ads.aqua.deployment import AquaDeploymentApp
8+
from ads.aqua.finetune import AquaFineTuningApp
9+
from ads.aqua.model import AquaModelApp
10+
from ads.aqua.evaluation import AquaEvaluationApp
11+
12+
13+
class AquaCommand:
14+
"""Contains the command groups for project Aqua."""
15+
16+
model = AquaModelApp
17+
fine_tuning = AquaFineTuningApp
18+
deployment = AquaDeploymentApp
19+
evaluation = AquaEvaluationApp

0 commit comments

Comments
 (0)