Skip to content

Commit 4ada401

Browse files
Resolving conflicts
2 parents c7318aa + 2098b24 commit 4ada401

File tree

11 files changed

+409
-112
lines changed

11 files changed

+409
-112
lines changed

ads/aqua/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
SUPPORTED_FILE_FORMATS = ["jsonl"]
5656
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
5757

58+
AQUA_CHAT_TEMPLATE_METADATA_KEY = "chat_template"
59+
5860
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
5961
"datasciencemodel": "models",
6062
"datasciencemodeldeployment": "model-deployments",

ads/aqua/extension/common_handler.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#!/usr/bin/env python
22
# Copyright (c) 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4-
5-
4+
import json
5+
import os
66
from importlib import metadata
77

88
import huggingface_hub
@@ -18,6 +18,10 @@
1818
)
1919
from ads.aqua.extension.base_handler import AquaAPIhandler
2020
from ads.aqua.extension.errors import Errors
21+
from ads.common.object_storage_details import ObjectStorageDetails
22+
from ads.common.utils import read_file
23+
from ads.config import CONDA_BUCKET_NAME, CONDA_BUCKET_NS
24+
from ads.opctl.operator.common.utils import default_signer
2125

2226

2327
class ADSVersionHandler(AquaAPIhandler):
@@ -28,6 +32,46 @@ def get(self):
2832
self.finish({"data": metadata.version("oracle_ads")})
2933

3034

35+
class AquaVersionHandler(AquaAPIhandler):
36+
@handle_exceptions
37+
def get(self):
38+
"""
39+
Returns the current and latest deployed version of AQUA
40+
41+
{
42+
"installed": {
43+
"aqua": "0.1.3.0",
44+
"ads": "2.14.2"
45+
},
46+
"latest": {
47+
"aqua": "0.1.4.0",
48+
"ads": "2.14.4"
49+
}
50+
}
51+
52+
"""
53+
54+
current_aqua_version_path = os.path.join(
55+
os.path.dirname(os.path.abspath(__file__)), "..", "version.json"
56+
)
57+
current_aqua_version = json.loads(read_file(current_aqua_version_path))
58+
current_ads_version = {"ads": metadata.version("oracle_ads")}
59+
current_version = {"installed": {**current_aqua_version, **current_ads_version}}
60+
try:
61+
latest_version_artifact_path = ObjectStorageDetails(
62+
CONDA_BUCKET_NAME,
63+
CONDA_BUCKET_NS,
64+
"service_pack/aqua_latest_version.json",
65+
).path
66+
latest_version = json.loads(
67+
read_file(latest_version_artifact_path, auth=default_signer())
68+
)
69+
except Exception:
70+
latest_version = {"latest": current_version["installed"]}
71+
response = {**current_version, **latest_version}
72+
return self.finish(response)
73+
74+
3175
class CompatibilityCheckHandler(AquaAPIhandler):
3276
"""The handler to check if the extension is compatible."""
3377

@@ -118,4 +162,5 @@ def get(self):
118162
("network_status", NetworkStatusHandler),
119163
("hf_login", HFLoginHandler),
120164
("hf_logged_in", HFUserStatusHandler),
165+
("aqua_version", AquaVersionHandler),
121166
]

ads/aqua/extension/model_handler.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
from ads.aqua.common.enums import CustomInferenceContainerTypeFamily
1212
from ads.aqua.common.errors import AquaRuntimeError
1313
from ads.aqua.common.utils import get_hf_model_info, is_valid_ocid, list_hf_models
14+
from ads.aqua.constants import AQUA_CHAT_TEMPLATE_METADATA_KEY
1415
from ads.aqua.extension.base_handler import AquaAPIhandler
1516
from ads.aqua.extension.errors import Errors
1617
from ads.aqua.model import AquaModelApp
1718
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
1819
from ads.config import SERVICE
20+
from ads.model import DataScienceModel
1921
from ads.model.common.utils import MetadataArtifactPathType
22+
from ads.model.service.oci_datascience_model import OCIDataScienceModel
2023

2124

2225
class AquaModelHandler(AquaAPIhandler):
@@ -320,26 +323,65 @@ def post(self, *args, **kwargs): # noqa: ARG002
320323
)
321324

322325

323-
class AquaModelTokenizerConfigHandler(AquaAPIhandler):
326+
class AquaModelChatTemplateHandler(AquaAPIhandler):
324327
def get(self, model_id):
325328
"""
326-
Handles requests for retrieving the Hugging Face tokenizer configuration of a specified model.
327-
Expected request format: GET /aqua/models/<model-ocid>/tokenizer
329+
Handles requests for retrieving the chat template from custom metadata of a specified model.
330+
Expected request format: GET /aqua/models/<model-ocid>/chat-template
328331
329332
"""
330333

331334
path_list = urlparse(self.request.path).path.strip("/").split("/")
332-
# Path should be /aqua/models/ocid1.iad.ahdxxx/tokenizer
333-
# path_list=['aqua','models','<model-ocid>','tokenizer']
335+
# Path should be /aqua/models/ocid1.iad.ahdxxx/chat-template
336+
# path_list=['aqua','models','<model-ocid>','chat-template']
334337
if (
335338
len(path_list) == 4
336339
and is_valid_ocid(path_list[2])
337-
and path_list[3] == "tokenizer"
340+
and path_list[3] == "chat-template"
338341
):
339-
return self.finish(AquaModelApp().get_hf_tokenizer_config(model_id))
342+
try:
343+
oci_data_science_model = OCIDataScienceModel.from_id(model_id)
344+
except Exception as e:
345+
raise HTTPError(404, f"Model not found for id: {model_id}. Details: {str(e)}")
346+
return self.finish(oci_data_science_model.get_custom_metadata_artifact("chat_template"))
340347

341348
raise HTTPError(400, f"The request {self.request.path} is invalid.")
342349

350+
@handle_exceptions
351+
def post(self, model_id: str):
352+
"""
353+
Handles POST requests to add a custom chat_template metadata artifact to a model.
354+
355+
Expected request format:
356+
POST /aqua/models/<model-ocid>/chat-template
357+
Body: { "chat_template": "<your_template_string>" }
358+
359+
"""
360+
try:
361+
input_body = self.get_json_body()
362+
except Exception as e:
363+
raise HTTPError(400, f"Invalid JSON body: {str(e)}")
364+
365+
chat_template = input_body.get("chat_template")
366+
if not chat_template:
367+
raise HTTPError(400, "Missing required field: 'chat_template'")
368+
369+
try:
370+
data_science_model = DataScienceModel.from_id(model_id)
371+
except Exception as e:
372+
raise HTTPError(404, f"Model not found for id: {model_id}. Details: {str(e)}")
373+
374+
try:
375+
result = data_science_model.create_custom_metadata_artifact(
376+
metadata_key_name=AQUA_CHAT_TEMPLATE_METADATA_KEY,
377+
path_type=MetadataArtifactPathType.CONTENT,
378+
artifact_path_or_content=chat_template.encode()
379+
)
380+
except Exception as e:
381+
raise HTTPError(500, f"Failed to create metadata artifact: {str(e)}")
382+
383+
return self.finish(result)
384+
343385

344386
class AquaModelDefinedMetadataArtifactHandler(AquaAPIhandler):
345387
"""
@@ -381,7 +423,7 @@ def post(self, model_id: str, metadata_key: str):
381423
("model/?([^/]*)", AquaModelHandler),
382424
("model/?([^/]*)/license", AquaModelLicenseHandler),
383425
("model/?([^/]*)/readme", AquaModelReadmeHandler),
384-
("model/?([^/]*)/tokenizer", AquaModelTokenizerConfigHandler),
426+
("model/?([^/]*)/chat-template", AquaModelChatTemplateHandler),
385427
("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
386428
(
387429
"model/?([^/]*)/definedMetadata/?([^/]*)",

ads/aqua/model/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class ModelTask(ExtendedEnum):
2626
TEXT_GENERATION = "text-generation"
2727
IMAGE_TEXT_TO_TEXT = "image-text-to-text"
2828
IMAGE_TO_TEXT = "image-to-text"
29+
TIME_SERIES_FORECASTING = "time-series-forecasting"
2930

3031

3132
class FineTuningMetricCategories(ExtendedEnum):

ads/aqua/modeldeployment/deployment.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55

66
import json
7+
import re
78
import shlex
89
import threading
910
from datetime import datetime, timedelta
@@ -47,7 +48,11 @@
4748
)
4849
from ads.aqua.data import AquaResourceIdentifier
4950
from ads.aqua.model import AquaModelApp
50-
from ads.aqua.model.constants import AquaModelMetadataKeys, ModelCustomMetadataFields
51+
from ads.aqua.model.constants import (
52+
AquaModelMetadataKeys,
53+
ModelCustomMetadataFields,
54+
ModelTask,
55+
)
5156
from ads.aqua.model.utils import (
5257
extract_base_model_from_ft,
5358
extract_fine_tune_artifacts_path,
@@ -214,6 +219,14 @@ def create(
214219
freeform_tags=freeform_tags,
215220
defined_tags=defined_tags,
216221
)
222+
task_tag = aqua_model.freeform_tags.get(Tags.TASK, UNKNOWN)
223+
if (
224+
task_tag == ModelTask.TIME_SERIES_FORECASTING
225+
or task_tag == ModelTask.TIME_SERIES_FORECASTING.replace("-", "_")
226+
):
227+
create_deployment_details.env_var.update(
228+
{Tags.TASK.upper(): ModelTask.TIME_SERIES_FORECASTING}
229+
)
217230
return self._create(
218231
aqua_model=aqua_model,
219232
create_deployment_details=create_deployment_details,
@@ -752,14 +765,16 @@ def _create_deployment(
752765
).deploy(wait_for_completion=False)
753766

754767
deployment_id = deployment.id
768+
755769
logger.info(
756770
f"Aqua model deployment {deployment_id} created for model {aqua_model_id}. Work request Id is {deployment.dsc_model_deployment.workflow_req_id}"
757771
)
772+
status_list = []
758773

759774
progress_thread = threading.Thread(
760775
target=self.get_deployment_status,
761776
args=(
762-
deployment_id,
777+
deployment,
763778
deployment.dsc_model_deployment.workflow_req_id,
764779
model_type,
765780
model_name,
@@ -1265,7 +1280,7 @@ def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]:
12651280

12661281
def get_deployment_status(
12671282
self,
1268-
model_deployment_id: str,
1283+
deployment: ModelDeployment,
12691284
work_request_id: str,
12701285
model_type: str,
12711286
model_name: str,
@@ -1287,37 +1302,60 @@ def get_deployment_status(
12871302
AquaDeployment
12881303
An Aqua deployment instance.
12891304
"""
1290-
ocid = get_ocid_substring(model_deployment_id, key_len=8)
1291-
telemetry_kwargs = {"ocid": ocid}
1292-
1305+
ocid = get_ocid_substring(deployment.id, key_len=8)
12931306
data_science_work_request: DataScienceWorkRequest = DataScienceWorkRequest(
12941307
work_request_id
12951308
)
1296-
12971309
try:
12981310
data_science_work_request.wait_work_request(
12991311
progress_bar_description="Creating model deployment",
13001312
max_wait_time=DEFAULT_WAIT_TIME,
13011313
poll_interval=DEFAULT_POLL_INTERVAL,
13021314
)
13031315
except Exception:
1316+
status = ""
1317+
logs = deployment.show_logs().sort_values(by="time", ascending=False)
1318+
1319+
if logs and len(logs) > 0:
1320+
status = logs.iloc[0]["message"]
1321+
1322+
status = re.sub(r"[^a-zA-Z0-9]", " ", status)
1323+
13041324
if data_science_work_request._error_message:
13051325
error_str = ""
13061326
for error in data_science_work_request._error_message:
13071327
error_str = error_str + " " + error.message
13081328

1309-
self.telemetry.record_event(
1310-
category=f"aqua/{model_type}/deployment/status",
1311-
action="FAILED",
1312-
detail=error_str,
1313-
value=model_name,
1314-
**telemetry_kwargs,
1315-
)
1329+
error_str = re.sub(r"[^a-zA-Z0-9]", " ", error_str)
1330+
telemetry_kwargs = {
1331+
"ocid": ocid,
1332+
"model_name": model_name,
1333+
"work_request_error": error_str,
1334+
"status": status,
1335+
}
1336+
1337+
self.telemetry.record_event(
1338+
category=f"aqua/{model_type}/deployment/status",
1339+
action="FAILED",
1340+
**telemetry_kwargs,
1341+
)
1342+
else:
1343+
telemetry_kwargs = {
1344+
"ocid": ocid,
1345+
"model_name": model_name,
1346+
"status": status,
1347+
}
1348+
1349+
self.telemetry.record_event(
1350+
category=f"aqua/{model_type}/deployment/status",
1351+
action="FAILED",
1352+
**telemetry_kwargs,
1353+
)
13161354

13171355
else:
1318-
self.telemetry.record_event_async(
1356+
telemetry_kwargs = {"ocid": ocid, "model_name": model_name}
1357+
self.telemetry.record_event(
13191358
category=f"aqua/{model_type}/deployment/status",
13201359
action="SUCCEEDED",
1321-
value=model_name,
13221360
**telemetry_kwargs,
13231361
)

ads/aqua/version.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"aqua": "1.0.7"
3+
}

ads/common/oci_logging.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

43
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import datetime
87
import logging
98
import time
10-
from typing import Dict, Union, List
9+
from typing import Dict, List, Union
1110

11+
import oci.exceptions
1212
import oci.logging
1313
import oci.loggingsearch
14-
import oci.exceptions
14+
1515
from ads.common.decorator.utils import class_or_instance_method
1616
from ads.common.oci_mixin import OCIModelMixin, OCIWorkRequestMixin
1717
from ads.common.oci_resource import OCIResource, ResourceNotFoundError
1818

19-
2019
logger = logging.getLogger(__name__)
2120

2221
# Maximum number of log records to be returned by default.
@@ -862,9 +861,7 @@ def tail(
862861
time_start=time_start,
863862
log_filter=log_filter,
864863
)
865-
self._print(
866-
sorted(tail_logs, key=lambda log: log["time"])
867-
)
864+
self._print(sorted(tail_logs, key=lambda log: log["time"]))
868865

869866
def head(
870867
self,

0 commit comments

Comments
 (0)