Skip to content

Commit 5c8f615

Browse files
committed
Merge branch 'main' of https://github.com/oracle/accelerated-data-science into ODSC-63978/onnx_model_support
2 parents 4681ead + eab7c47 commit 5c8f615

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1242
-315
lines changed

THIRD_PARTY_LICENSES.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ autots
2424
* Source code: https://github.com/winedarksea/AutoTS
2525
* Project home: https://winedarksea.github.io/AutoTS/build/html/index.html
2626

27+
autogen
28+
* Copyright (c) 2024 Microsoft Corporation.
29+
* License: MIT License
30+
* Source code: https://github.com/microsoft/autogen
31+
* Project home: microsoft.github.io/autogen/
32+
2733
bokeh
2834
* Copyright Copyright (c) 2012 - 2021, Anaconda, Inc., and Bokeh Contributors
2935
* License: BSD 3-Clause "New" or "Revised" License

ads/aqua/app.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright (c) 2024 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
import json
56
import os
67
from dataclasses import fields
78
from typing import Dict, Union
@@ -135,6 +136,8 @@ def create_model_version_set(
135136
description: str = None,
136137
compartment_id: str = None,
137138
project_id: str = None,
139+
freeform_tags: dict = None,
140+
defined_tags: dict = None,
138141
**kwargs,
139142
) -> tuple:
140143
"""Creates ModelVersionSet from given ID or Name.
@@ -153,7 +156,10 @@ def create_model_version_set(
153156
Project OCID.
154157
tag: (str, optional)
155158
calling tag, can be Tags.AQUA_FINE_TUNING or Tags.AQUA_EVALUATION
156-
159+
freeform_tags: (dict, optional)
160+
Freeform tags for the model version set
161+
defined_tags: (dict, optional)
162+
Defined tags for the model version set
157163
Returns
158164
-------
159165
tuple: (model_version_set_id, model_version_set_name)
@@ -182,13 +188,15 @@ def create_model_version_set(
182188
mvs_freeform_tags = {
183189
tag: tag,
184190
}
191+
mvs_freeform_tags = {**mvs_freeform_tags, **(freeform_tags or {})}
185192
model_version_set = (
186193
ModelVersionSet()
187194
.with_compartment_id(compartment_id)
188195
.with_project_id(project_id)
189196
.with_name(model_version_set_name)
190197
.with_description(description)
191198
.with_freeform_tags(**mvs_freeform_tags)
199+
.with_defined_tags(**(defined_tags or {}))
192200
# TODO: decide what parameters will be needed
193201
# when refactor eval to use this method, we need to pass tag here.
194202
.create(**kwargs)
@@ -340,7 +348,9 @@ def build_cli(self) -> str:
340348
"""
341349
cmd = f"ads aqua {self._command}"
342350
params = [
343-
f"--{field.name} {getattr(self,field.name)}"
351+
f"--{field.name} {json.dumps(getattr(self, field.name))}"
352+
if isinstance(getattr(self, field.name), dict)
353+
else f"--{field.name} {getattr(self, field.name)}"
344354
for field in fields(self.__class__)
345355
if getattr(self, field.name) is not None
346356
]

ads/aqua/common/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,13 +788,14 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
788788
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
789789

790790

791-
def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
791+
def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
792792
"""Upload the local folder to the object storage
793793
794794
Args:
795795
os_path (str): object storage URI with prefix. This is the path to upload
796796
local_dir (str): Local directory where the object is downloaded
797797
model_name (str): Name of the huggingface model
798+
exclude_pattern (optional, str): The matching pattern of files to be excluded from uploading.
798799
Retuns:
799800
str: Object name inside the bucket
800801
"""
@@ -804,6 +805,8 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
804805
auth_state = AuthState()
805806
object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
806807
command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile} --no-overwrite"
808+
if exclude_pattern:
809+
command += f" --exclude {exclude_pattern}"
807810
try:
808811
logger.info(f"Running: {command}")
809812
subprocess.check_call(shlex.split(command))

ads/aqua/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
3636
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
3737
AQUA_MODEL_ARTIFACT_FILE = "model_file"
38+
HF_METADATA_FOLDER = ".cache/"
3839
HF_LOGIN_DEFAULT_TIMEOUT = 2
3940

4041
TRAINING_METRICS_FINAL = "training_metrics_final"

ads/aqua/evaluation/entities.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ class CreateAquaEvaluationDetails(Serializable):
6464
The metrics for the evaluation.
6565
force_overwrite: (bool, optional). Defaults to `False`.
6666
Whether to force overwrite the existing file in object storage.
67+
freeform_tags: (dict, optional)
68+
Freeform tags for the evaluation model
69+
defined_tags: (dict, optional)
70+
Defined tags for the evaluation model
6771
"""
6872

6973
evaluation_source_id: str
@@ -83,8 +87,10 @@ class CreateAquaEvaluationDetails(Serializable):
8387
ocpus: Optional[float] = None
8488
log_group_id: Optional[str] = None
8589
log_id: Optional[str] = None
86-
metrics: Optional[List[str]] = None
90+
metrics: Optional[List[Dict[str, Any]]] = None
8791
force_overwrite: Optional[bool] = False
92+
freeform_tags: Optional[dict] = None
93+
defined_tags: Optional[dict] = None
8894

8995
class Config:
9096
extra = "ignore"
@@ -140,7 +146,7 @@ class AquaEvaluationCommands(Serializable):
140146
evaluation_id: str
141147
evaluation_target_id: str
142148
input_data: Dict[str, Any]
143-
metrics: List[str]
149+
metrics: List[Dict[str, Any]]
144150
output_dir: str
145151
params: Dict[str, Any]
146152

ads/aqua/evaluation/evaluation.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ def create(
159159
create_aqua_evaluation_details = CreateAquaEvaluationDetails(**kwargs)
160160
except Exception as ex:
161161
custom_errors = {
162-
".".join(map(str, e["loc"])): e["msg"] for e in json.loads(ex.json())
162+
".".join(map(str, e["loc"])): e["msg"]
163+
for e in json.loads(ex.json())
163164
}
164165
raise AquaValueError(
165166
f"Invalid create evaluation parameters. Error details: {custom_errors}."
@@ -296,6 +297,10 @@ def create(
296297
evaluation_mvs_freeform_tags = {
297298
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
298299
}
300+
evaluation_mvs_freeform_tags = {
301+
**evaluation_mvs_freeform_tags,
302+
**(create_aqua_evaluation_details.freeform_tags or {}),
303+
}
299304

300305
model_version_set = (
301306
ModelVersionSet()
@@ -306,6 +311,9 @@ def create(
306311
create_aqua_evaluation_details.experiment_description
307312
)
308313
.with_freeform_tags(**evaluation_mvs_freeform_tags)
314+
.with_defined_tags(
315+
**(create_aqua_evaluation_details.defined_tags or {})
316+
)
309317
# TODO: decide what parameters will be needed
310318
.create(**kwargs)
311319
)
@@ -368,6 +376,10 @@ def create(
368376
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
369377
Tags.AQUA_EVALUATION_MODEL_ID: evaluation_model.id,
370378
}
379+
evaluation_job_freeform_tags = {
380+
**evaluation_job_freeform_tags,
381+
**(create_aqua_evaluation_details.freeform_tags or {}),
382+
}
371383

372384
evaluation_job = Job(name=evaluation_model.display_name).with_infrastructure(
373385
DataScienceJob()
@@ -378,6 +390,7 @@ def create(
378390
.with_shape_name(create_aqua_evaluation_details.shape_name)
379391
.with_block_storage_size(create_aqua_evaluation_details.block_storage_size)
380392
.with_freeform_tag(**evaluation_job_freeform_tags)
393+
.with_defined_tag(**(create_aqua_evaluation_details.defined_tags or {}))
381394
)
382395
if (
383396
create_aqua_evaluation_details.memory_in_gbs
@@ -424,6 +437,7 @@ def create(
424437
evaluation_job_run = evaluation_job.run(
425438
name=evaluation_model.display_name,
426439
freeform_tags=evaluation_job_freeform_tags,
440+
defined_tags=(create_aqua_evaluation_details.defined_tags or {}),
427441
wait=False,
428442
)
429443
logger.debug(
@@ -443,13 +457,20 @@ def create(
443457
for metadata in evaluation_model_custom_metadata.to_dict()["data"]
444458
]
445459

460+
evaluation_model_freeform_tags = {
461+
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
462+
**(create_aqua_evaluation_details.freeform_tags or {}),
463+
}
464+
evaluation_model_defined_tags = (
465+
create_aqua_evaluation_details.defined_tags or {}
466+
)
467+
446468
self.ds_client.update_model(
447469
model_id=evaluation_model.id,
448470
update_model_details=UpdateModelDetails(
449471
custom_metadata_list=updated_custom_metadata_list,
450-
freeform_tags={
451-
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
452-
},
472+
freeform_tags=evaluation_model_freeform_tags,
473+
defined_tags=evaluation_model_defined_tags,
453474
),
454475
)
455476

@@ -523,6 +544,8 @@ def create(
523544
"evaluation_job_id": evaluation_job.id,
524545
"evaluation_source": create_aqua_evaluation_details.evaluation_source_id,
525546
"evaluation_experiment_id": experiment_model_version_set_id,
547+
**evaluation_model_freeform_tags,
548+
**evaluation_model_defined_tags,
526549
},
527550
parameters=AquaEvalParams(),
528551
)
@@ -619,11 +642,6 @@ def _build_launch_cmd(
619642
evaluation_id=evaluation_id,
620643
evaluation_target_id=evaluation_source_id,
621644
input_data={
622-
"columns": {
623-
"prompt": "prompt",
624-
"completion": "completion",
625-
"category": "category",
626-
},
627645
"format": Path(dataset_path).suffix,
628646
"url": dataset_path,
629647
},

ads/aqua/extension/deployment_handler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def delete(self, model_deployment_id):
5959
return self.finish(AquaDeploymentApp().delete(model_deployment_id))
6060

6161
@handle_exceptions
62-
def put(self, *args, **kwargs):
62+
def put(self, *args, **kwargs): # noqa: ARG002
6363
"""
6464
Handles put request for the activating and deactivating OCI datascience model deployments
6565
Raises
@@ -82,7 +82,7 @@ def put(self, *args, **kwargs):
8282
raise HTTPError(400, f"The request {self.request.path} is invalid.")
8383

8484
@handle_exceptions
85-
def post(self, *args, **kwargs):
85+
def post(self, *args, **kwargs): # noqa: ARG002
8686
"""
8787
Handles post request for the deployment APIs
8888
Raises
@@ -132,6 +132,8 @@ def post(self, *args, **kwargs):
132132
private_endpoint_id = input_data.get("private_endpoint_id")
133133
container_image_uri = input_data.get("container_image_uri")
134134
cmd_var = input_data.get("cmd_var")
135+
freeform_tags = input_data.get("freeform_tags")
136+
defined_tags = input_data.get("defined_tags")
135137

136138
self.finish(
137139
AquaDeploymentApp().create(
@@ -157,6 +159,8 @@ def post(self, *args, **kwargs):
157159
private_endpoint_id=private_endpoint_id,
158160
container_image_uri=container_image_uri,
159161
cmd_var=cmd_var,
162+
freeform_tags=freeform_tags,
163+
defined_tags=defined_tags,
160164
)
161165
)
162166

@@ -196,7 +200,7 @@ def validate_predict_url(endpoint):
196200
return False
197201

198202
@handle_exceptions
199-
def post(self, *args, **kwargs):
203+
def post(self, *args, **kwargs): # noqa: ARG002
200204
"""
201205
Handles inference request for the Active Model Deployments
202206
Raises
@@ -262,7 +266,7 @@ def get(self, model_id):
262266
)
263267

264268
@handle_exceptions
265-
def post(self, *args, **kwargs):
269+
def post(self, *args, **kwargs): # noqa: ARG002
266270
"""Handles post request for the deployment param handler API.
267271
268272
Raises

ads/aqua/extension/model_handler.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def list(self):
9696
)
9797

9898
@handle_exceptions
99-
def post(self, *args, **kwargs):
99+
def post(self, *args, **kwargs): # noqa: ARG002
100100
"""
101101
Handles post request for the registering any Aqua model.
102102
Raises
@@ -129,6 +129,10 @@ def post(self, *args, **kwargs):
129129
str(input_data.get("download_from_hf", "false")).lower() == "true"
130130
)
131131
inference_container_uri = input_data.get("inference_container_uri")
132+
allow_patterns = input_data.get("allow_patterns")
133+
ignore_patterns = input_data.get("ignore_patterns")
134+
freeform_tags = input_data.get("freeform_tags")
135+
defined_tags = input_data.get("defined_tags")
132136

133137
return self.finish(
134138
AquaModelApp().register(
@@ -141,6 +145,10 @@ def post(self, *args, **kwargs):
141145
project_id=project_id,
142146
model_file=model_file,
143147
inference_container_uri=inference_container_uri,
148+
allow_patterns=allow_patterns,
149+
ignore_patterns=ignore_patterns,
150+
freeform_tags=freeform_tags,
151+
defined_tags=defined_tags,
144152
)
145153
)
146154

@@ -166,11 +174,9 @@ def put(self, id):
166174

167175
enable_finetuning = input_data.get("enable_finetuning")
168176
task = input_data.get("task")
169-
app=AquaModelApp()
177+
app = AquaModelApp()
170178
self.finish(
171-
app.edit_registered_model(
172-
id, inference_container, enable_finetuning, task
173-
)
179+
app.edit_registered_model(id, inference_container, enable_finetuning, task)
174180
)
175181
app.clear_model_details_cache(model_id=id)
176182

@@ -214,7 +220,7 @@ def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:
214220
return None
215221

216222
@handle_exceptions
217-
def get(self, *args, **kwargs):
223+
def get(self, *args, **kwargs): # noqa: ARG002
218224
"""
219225
Finds a list of matching models from hugging face based on query string provided from users.
220226
@@ -235,7 +241,7 @@ def get(self, *args, **kwargs):
235241
return self.finish({"models": models})
236242

237243
@handle_exceptions
238-
def post(self, *args, **kwargs):
244+
def post(self, *args, **kwargs): # noqa: ARG002
239245
"""Handles post request for the HF Models APIs
240246
241247
Raises

ads/aqua/extension/ui_handler.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def get(self, id=""):
6868
return self.list_buckets()
6969
elif paths.startswith("aqua/job/shapes"):
7070
return self.list_job_shapes()
71+
elif paths.startswith("aqua/modeldeployment/shapes"):
72+
return self.list_model_deployment_shapes()
7173
elif paths.startswith("aqua/vcn"):
7274
return self.list_vcn()
7375
elif paths.startswith("aqua/subnets"):
@@ -160,6 +162,15 @@ def list_job_shapes(self, **kwargs):
160162
AquaUIApp().list_job_shapes(compartment_id=compartment_id, **kwargs)
161163
)
162164

165+
def list_model_deployment_shapes(self, **kwargs):
166+
"""Lists model deployment shapes available in the specified compartment."""
167+
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
168+
return self.finish(
169+
AquaUIApp().list_model_deployment_shapes(
170+
compartment_id=compartment_id, **kwargs
171+
)
172+
)
173+
163174
def list_vcn(self, **kwargs):
164175
"""Lists the virtual cloud networks (VCNs) in the specified compartment."""
165176
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
@@ -255,8 +266,9 @@ def post(self, *args, **kwargs):
255266
__handlers__ = [
256267
("logging/?([^/]*)", AquaUIHandler),
257268
("compartments/?([^/]*)", AquaUIHandler),
258-
# TODO: change url to evaluation/experiements/?([^/]*)
269+
# TODO: change url to evaluation/experiments/?([^/]*)
259270
("experiment/?([^/]*)", AquaUIHandler),
271+
("modeldeployment/?([^/]*)", AquaUIHandler),
260272
("versionsets/?([^/]*)", AquaUIHandler),
261273
("buckets/?([^/]*)", AquaUIHandler),
262274
("job/shapes/?([^/]*)", AquaUIHandler),

0 commit comments

Comments
 (0)