Skip to content

Commit c87b1fa

Browse files
Merge branch 'main' into ODSC-64654-register-model-artifact-reference
2 parents 9223d68 + b643faa commit c87b1fa

File tree

83 files changed

+3198
-487
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

83 files changed

+3198
-487
lines changed

.github/workflows/run-unittests-default_setup.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: "[Py3.8-3.11] - Default Tests"
1+
name: "[Py3.9-3.11] - Default Tests"
22

33
on:
44
workflow_dispatch:
@@ -33,7 +33,7 @@ jobs:
3333
strategy:
3434
fail-fast: false
3535
matrix:
36-
python-version: ["3.8", "3.9", "3.10", "3.11"]
36+
python-version: ["3.9", "3.10", "3.11"]
3737

3838
steps:
3939
- uses: actions/checkout@v4

.github/workflows/run-unittests-py39-py310.yml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,16 @@ jobs:
7474
name: "Test env setup"
7575
timeout-minutes: 30
7676

77-
- name: "Run hpo tests"
78-
timeout-minutes: 10
79-
shell: bash
80-
if: ${{ matrix.name }} == "unitary"
81-
run: |
82-
set -x # print commands that are executed
77+
# - name: "Run hpo tests"
78+
# timeout-minutes: 10
79+
# shell: bash
80+
# if: ${{ matrix.name }} == "unitary"
81+
# run: |
82+
# set -x # print commands that are executed
8383

84-
# Run hpo tests, which hangs if run together with all unitary tests
85-
python -m pytest -v -p no:warnings -n auto --dist loadfile \
86-
tests/unitary/with_extras/hpo
84+
# # Run hpo tests, which hangs if run together with all unitary tests
85+
# python -m pytest -v -p no:warnings -n auto --dist loadfile \
86+
# tests/unitary/with_extras/hpo
8787

8888
- name: "Run unitary tests folder with maximum ADS dependencies"
8989
timeout-minutes: 60

ads/aqua/app.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright (c) 2024 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
import json
56
import os
67
from dataclasses import fields
78
from typing import Dict, Union
@@ -135,6 +136,8 @@ def create_model_version_set(
135136
description: str = None,
136137
compartment_id: str = None,
137138
project_id: str = None,
139+
freeform_tags: dict = None,
140+
defined_tags: dict = None,
138141
**kwargs,
139142
) -> tuple:
140143
"""Creates ModelVersionSet from given ID or Name.
@@ -153,7 +156,10 @@ def create_model_version_set(
153156
Project OCID.
154157
tag: (str, optional)
155158
calling tag, can be Tags.AQUA_FINE_TUNING or Tags.AQUA_EVALUATION
156-
159+
freeform_tags: (dict, optional)
160+
Freeform tags for the model version set
161+
defined_tags: (dict, optional)
162+
Defined tags for the model version set
157163
Returns
158164
-------
159165
tuple: (model_version_set_id, model_version_set_name)
@@ -182,13 +188,15 @@ def create_model_version_set(
182188
mvs_freeform_tags = {
183189
tag: tag,
184190
}
191+
mvs_freeform_tags = {**mvs_freeform_tags, **(freeform_tags or {})}
185192
model_version_set = (
186193
ModelVersionSet()
187194
.with_compartment_id(compartment_id)
188195
.with_project_id(project_id)
189196
.with_name(model_version_set_name)
190197
.with_description(description)
191198
.with_freeform_tags(**mvs_freeform_tags)
199+
.with_defined_tags(**(defined_tags or {}))
192200
# TODO: decide what parameters will be needed
193201
# when refactor eval to use this method, we need to pass tag here.
194202
.create(**kwargs)
@@ -340,7 +348,9 @@ def build_cli(self) -> str:
340348
"""
341349
cmd = f"ads aqua {self._command}"
342350
params = [
343-
f"--{field.name} {getattr(self,field.name)}"
351+
f"--{field.name} {json.dumps(getattr(self, field.name))}"
352+
if isinstance(getattr(self, field.name), dict)
353+
else f"--{field.name} {getattr(self, field.name)}"
344354
for field in fields(self.__class__)
345355
if getattr(self, field.name) is not None
346356
]

ads/aqua/evaluation/entities.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ class CreateAquaEvaluationDetails(Serializable):
6464
The metrics for the evaluation.
6565
force_overwrite: (bool, optional). Defaults to `False`.
6666
Whether to force overwrite the existing file in object storage.
67+
freeform_tags: (dict, optional)
68+
Freeform tags for the evaluation model
69+
defined_tags: (dict, optional)
70+
Defined tags for the evaluation model
6771
"""
6872

6973
evaluation_source_id: str
@@ -83,8 +87,10 @@ class CreateAquaEvaluationDetails(Serializable):
8387
ocpus: Optional[float] = None
8488
log_group_id: Optional[str] = None
8589
log_id: Optional[str] = None
86-
metrics: Optional[List[str]] = None
90+
metrics: Optional[List[Dict[str, Any]]] = None
8791
force_overwrite: Optional[bool] = False
92+
freeform_tags: Optional[dict] = None
93+
defined_tags: Optional[dict] = None
8894

8995
class Config:
9096
extra = "ignore"
@@ -140,7 +146,7 @@ class AquaEvaluationCommands(Serializable):
140146
evaluation_id: str
141147
evaluation_target_id: str
142148
input_data: Dict[str, Any]
143-
metrics: List[str]
149+
metrics: List[Dict[str, Any]]
144150
output_dir: str
145151
params: Dict[str, Any]
146152

ads/aqua/evaluation/evaluation.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ def create(
159159
create_aqua_evaluation_details = CreateAquaEvaluationDetails(**kwargs)
160160
except Exception as ex:
161161
custom_errors = {
162-
".".join(map(str, e["loc"])): e["msg"] for e in json.loads(ex.json())
162+
".".join(map(str, e["loc"])): e["msg"]
163+
for e in json.loads(ex.json())
163164
}
164165
raise AquaValueError(
165166
f"Invalid create evaluation parameters. Error details: {custom_errors}."
@@ -296,6 +297,10 @@ def create(
296297
evaluation_mvs_freeform_tags = {
297298
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
298299
}
300+
evaluation_mvs_freeform_tags = {
301+
**evaluation_mvs_freeform_tags,
302+
**(create_aqua_evaluation_details.freeform_tags or {}),
303+
}
299304

300305
model_version_set = (
301306
ModelVersionSet()
@@ -306,6 +311,9 @@ def create(
306311
create_aqua_evaluation_details.experiment_description
307312
)
308313
.with_freeform_tags(**evaluation_mvs_freeform_tags)
314+
.with_defined_tags(
315+
**(create_aqua_evaluation_details.defined_tags or {})
316+
)
309317
# TODO: decide what parameters will be needed
310318
.create(**kwargs)
311319
)
@@ -368,6 +376,10 @@ def create(
368376
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
369377
Tags.AQUA_EVALUATION_MODEL_ID: evaluation_model.id,
370378
}
379+
evaluation_job_freeform_tags = {
380+
**evaluation_job_freeform_tags,
381+
**(create_aqua_evaluation_details.freeform_tags or {}),
382+
}
371383

372384
evaluation_job = Job(name=evaluation_model.display_name).with_infrastructure(
373385
DataScienceJob()
@@ -378,6 +390,7 @@ def create(
378390
.with_shape_name(create_aqua_evaluation_details.shape_name)
379391
.with_block_storage_size(create_aqua_evaluation_details.block_storage_size)
380392
.with_freeform_tag(**evaluation_job_freeform_tags)
393+
.with_defined_tag(**(create_aqua_evaluation_details.defined_tags or {}))
381394
)
382395
if (
383396
create_aqua_evaluation_details.memory_in_gbs
@@ -424,6 +437,7 @@ def create(
424437
evaluation_job_run = evaluation_job.run(
425438
name=evaluation_model.display_name,
426439
freeform_tags=evaluation_job_freeform_tags,
440+
defined_tags=(create_aqua_evaluation_details.defined_tags or {}),
427441
wait=False,
428442
)
429443
logger.debug(
@@ -443,13 +457,20 @@ def create(
443457
for metadata in evaluation_model_custom_metadata.to_dict()["data"]
444458
]
445459

460+
evaluation_model_freeform_tags = {
461+
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
462+
**(create_aqua_evaluation_details.freeform_tags or {}),
463+
}
464+
evaluation_model_defined_tags = (
465+
create_aqua_evaluation_details.defined_tags or {}
466+
)
467+
446468
self.ds_client.update_model(
447469
model_id=evaluation_model.id,
448470
update_model_details=UpdateModelDetails(
449471
custom_metadata_list=updated_custom_metadata_list,
450-
freeform_tags={
451-
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
452-
},
472+
freeform_tags=evaluation_model_freeform_tags,
473+
defined_tags=evaluation_model_defined_tags,
453474
),
454475
)
455476

@@ -523,6 +544,8 @@ def create(
523544
"evaluation_job_id": evaluation_job.id,
524545
"evaluation_source": create_aqua_evaluation_details.evaluation_source_id,
525546
"evaluation_experiment_id": experiment_model_version_set_id,
547+
**evaluation_model_freeform_tags,
548+
**evaluation_model_defined_tags,
526549
},
527550
parameters=AquaEvalParams(),
528551
)
@@ -619,11 +642,6 @@ def _build_launch_cmd(
619642
evaluation_id=evaluation_id,
620643
evaluation_target_id=evaluation_source_id,
621644
input_data={
622-
"columns": {
623-
"prompt": "prompt",
624-
"completion": "completion",
625-
"category": "category",
626-
},
627645
"format": Path(dataset_path).suffix,
628646
"url": dataset_path,
629647
},

ads/aqua/extension/deployment_handler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def delete(self, model_deployment_id):
5959
return self.finish(AquaDeploymentApp().delete(model_deployment_id))
6060

6161
@handle_exceptions
62-
def put(self, *args, **kwargs):
62+
def put(self, *args, **kwargs): # noqa: ARG002
6363
"""
6464
Handles put request for the activating and deactivating OCI datascience model deployments
6565
Raises
@@ -82,7 +82,7 @@ def put(self, *args, **kwargs):
8282
raise HTTPError(400, f"The request {self.request.path} is invalid.")
8383

8484
@handle_exceptions
85-
def post(self, *args, **kwargs):
85+
def post(self, *args, **kwargs): # noqa: ARG002
8686
"""
8787
Handles post request for the deployment APIs
8888
Raises
@@ -132,6 +132,8 @@ def post(self, *args, **kwargs):
132132
private_endpoint_id = input_data.get("private_endpoint_id")
133133
container_image_uri = input_data.get("container_image_uri")
134134
cmd_var = input_data.get("cmd_var")
135+
freeform_tags = input_data.get("freeform_tags")
136+
defined_tags = input_data.get("defined_tags")
135137

136138
self.finish(
137139
AquaDeploymentApp().create(
@@ -157,6 +159,8 @@ def post(self, *args, **kwargs):
157159
private_endpoint_id=private_endpoint_id,
158160
container_image_uri=container_image_uri,
159161
cmd_var=cmd_var,
162+
freeform_tags=freeform_tags,
163+
defined_tags=defined_tags,
160164
)
161165
)
162166

@@ -196,7 +200,7 @@ def validate_predict_url(endpoint):
196200
return False
197201

198202
@handle_exceptions
199-
def post(self, *args, **kwargs):
203+
def post(self, *args, **kwargs): # noqa: ARG002
200204
"""
201205
Handles inference request for the Active Model Deployments
202206
Raises
@@ -262,7 +266,7 @@ def get(self, model_id):
262266
)
263267

264268
@handle_exceptions
265-
def post(self, *args, **kwargs):
269+
def post(self, *args, **kwargs): # noqa: ARG002
266270
"""Handles post request for the deployment param handler API.
267271
268272
Raises

ads/aqua/extension/model_handler.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def list(self):
9696
)
9797

9898
@handle_exceptions
99-
def post(self, *args, **kwargs):
99+
def post(self, *args, **kwargs): # noqa: ARG002
100100
"""
101101
Handles post request for the registering any Aqua model.
102102
Raises
@@ -131,6 +131,8 @@ def post(self, *args, **kwargs):
131131
inference_container_uri = input_data.get("inference_container_uri")
132132
allow_patterns = input_data.get("allow_patterns")
133133
ignore_patterns = input_data.get("ignore_patterns")
134+
freeform_tags = input_data.get("freeform_tags")
135+
defined_tags = input_data.get("defined_tags")
134136

135137
return self.finish(
136138
AquaModelApp().register(
@@ -145,6 +147,8 @@ def post(self, *args, **kwargs):
145147
inference_container_uri=inference_container_uri,
146148
allow_patterns=allow_patterns,
147149
ignore_patterns=ignore_patterns,
150+
freeform_tags=freeform_tags,
151+
defined_tags=defined_tags,
148152
)
149153
)
150154

@@ -170,11 +174,9 @@ def put(self, id):
170174

171175
enable_finetuning = input_data.get("enable_finetuning")
172176
task = input_data.get("task")
173-
app=AquaModelApp()
177+
app = AquaModelApp()
174178
self.finish(
175-
app.edit_registered_model(
176-
id, inference_container, enable_finetuning, task
177-
)
179+
app.edit_registered_model(id, inference_container, enable_finetuning, task)
178180
)
179181
app.clear_model_details_cache(model_id=id)
180182

@@ -218,7 +220,7 @@ def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:
218220
return None
219221

220222
@handle_exceptions
221-
def get(self, *args, **kwargs):
223+
def get(self, *args, **kwargs): # noqa: ARG002
222224
"""
223225
Finds a list of matching models from hugging face based on query string provided from users.
224226
@@ -239,7 +241,7 @@ def get(self, *args, **kwargs):
239241
return self.finish({"models": models})
240242

241243
@handle_exceptions
242-
def post(self, *args, **kwargs):
244+
def post(self, *args, **kwargs): # noqa: ARG002
243245
"""Handles post request for the HF Models APIs
244246
245247
Raises

ads/aqua/extension/ui_handler.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def get(self, id=""):
6868
return self.list_buckets()
6969
elif paths.startswith("aqua/job/shapes"):
7070
return self.list_job_shapes()
71+
elif paths.startswith("aqua/modeldeployment/shapes"):
72+
return self.list_model_deployment_shapes()
7173
elif paths.startswith("aqua/vcn"):
7274
return self.list_vcn()
7375
elif paths.startswith("aqua/subnets"):
@@ -160,6 +162,15 @@ def list_job_shapes(self, **kwargs):
160162
AquaUIApp().list_job_shapes(compartment_id=compartment_id, **kwargs)
161163
)
162164

165+
def list_model_deployment_shapes(self, **kwargs):
166+
"""Lists model deployment shapes available in the specified compartment."""
167+
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
168+
return self.finish(
169+
AquaUIApp().list_model_deployment_shapes(
170+
compartment_id=compartment_id, **kwargs
171+
)
172+
)
173+
163174
def list_vcn(self, **kwargs):
164175
"""Lists the virtual cloud networks (VCNs) in the specified compartment."""
165176
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
@@ -255,8 +266,9 @@ def post(self, *args, **kwargs):
255266
__handlers__ = [
256267
("logging/?([^/]*)", AquaUIHandler),
257268
("compartments/?([^/]*)", AquaUIHandler),
258-
# TODO: change url to evaluation/experiements/?([^/]*)
269+
# TODO: change url to evaluation/experiments/?([^/]*)
259270
("experiment/?([^/]*)", AquaUIHandler),
271+
("modeldeployment/?([^/]*)", AquaUIHandler),
260272
("versionsets/?([^/]*)", AquaUIHandler),
261273
("buckets/?([^/]*)", AquaUIHandler),
262274
("job/shapes/?([^/]*)", AquaUIHandler),

ads/aqua/finetuning/entities.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class CreateFineTuningDetails(DataClassSerializable):
8080
The log id for fine tuning job infrastructure.
8181
force_overwrite: (bool, optional). Defaults to `False`.
8282
Whether to force overwrite the existing file in object storage.
83+
freeform_tags: (dict, optional)
84+
Freeform tags for the fine-tuning model
85+
defined_tags: (dict, optional)
86+
Defined tags for the fine-tuning model
8387
"""
8488

8589
ft_source_id: str
@@ -101,3 +105,5 @@ class CreateFineTuningDetails(DataClassSerializable):
101105
log_id: Optional[str] = None
102106
log_group_id: Optional[str] = None
103107
force_overwrite: Optional[bool] = False
108+
freeform_tags: Optional[dict] = None
109+
defined_tags: Optional[dict] = None

0 commit comments

Comments
 (0)