Skip to content

Commit 707e96c

Browse files
Adding default chat_template api
1 parent b970eb5 commit 707e96c

File tree

3 files changed

+77
-1
lines changed

3 files changed

+77
-1
lines changed

ads/aqua/app.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
get_artifact_path,
2020
is_valid_ocid,
2121
load_config,
22+
read_file,
2223
)
2324
from ads.aqua.constants import UNKNOWN
2425
from ads.common import oci_client as oc
@@ -328,6 +329,53 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
328329

329330
return config
330331

332+
def get_chat_template(self, model_id):
333+
"""Gets the default chat template for the given Aqua model.
334+
335+
Parameters
336+
----------
337+
model_id: str
338+
The OCID of the Aqua model.
339+
340+
Returns
341+
-------
342+
str:
343+
Chat template string.
344+
"""
345+
chat_template = ""
346+
oci_model = self.ds_client.get_model(model_id).data
347+
oci_aqua = (
348+
(
349+
Tags.AQUA_TAG in oci_model.freeform_tags
350+
or Tags.AQUA_TAG.lower() in oci_model.freeform_tags
351+
)
352+
if oci_model.freeform_tags
353+
else False
354+
)
355+
356+
if not oci_aqua:
357+
raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
358+
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
359+
if not artifact_path:
360+
logger.debug(
361+
f"Failed to get artifact path from custom metadata for the model: {model_id}"
362+
)
363+
return chat_template
364+
365+
try:
366+
tokenizer_path = f"{os.path.dirname(artifact_path)}/tokenizer_config.json"
367+
chat_template = read_file(tokenizer_path)
368+
except Exception:
369+
pass
370+
371+
if not chat_template:
372+
logger.error(
373+
f"No default chat template is available for the model: {model_id}."
374+
)
375+
return chat_template
376+
377+
return chat_template
378+
331379
@property
332380
def telemetry(self):
333381
if not self._telemetry:

ads/aqua/common/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,10 @@ def read_file(file_path: str, **kwargs) -> str:
239239

240240
@threaded()
241241
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
242-
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
242+
if config_file_name:
243+
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
244+
else:
245+
artifact_path = f"{file_path.rstrip('/')}"
243246
signer = default_signer() if artifact_path.startswith("oci://") else {}
244247
config = json.loads(
245248
read_file(file_path=artifact_path, auth=signer, **kwargs) or UNKNOWN_JSON_STR

ads/aqua/extension/model_handler.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def get(
3434
"""Handle GET request."""
3535
url_parse = urlparse(self.request.path)
3636
paths = url_parse.path.strip("/")
37+
path_list = paths.split("/")
38+
print(path_list)
3739
if paths.startswith("aqua/model/files"):
3840
os_path = self.get_argument("os_path", None)
3941
model_name = self.get_argument("model_name", None)
@@ -63,6 +65,12 @@ def get(
6365
"os_path", "model_name"
6466
),
6567
)
68+
elif (
69+
len(path_list) == 4
70+
and path_list[2].startswith("ocid1.datasciencemodel")
71+
and path_list[3] == "chat_templates"
72+
):
73+
return self.get_chat_template(model_id)
6674
elif not model_id:
6775
return self.list()
6876

@@ -316,8 +324,25 @@ def post(self, *args, **kwargs): # noqa: ARG002
316324
)
317325

318326

327+
class AquaModelChatTemplateHandler(AquaAPIhandler):
328+
def get(self, model_id):
329+
url_parse = urlparse(self.request.path)
330+
paths = url_parse.path.strip("/")
331+
path_list = paths.split("/")
332+
print(path_list)
333+
if (
334+
len(path_list) == 4
335+
and path_list[2].startswith("ocid1.datasciencemodel")
336+
and path_list[3] == "chat_template"
337+
):
338+
return self.finish(AquaModelApp().get_chat_template(model_id))
339+
else:
340+
raise HTTPError(400, f"The request {self.request.path} is invalid.")
341+
342+
319343
__handlers__ = [
320344
("model/?([^/]*)", AquaModelHandler),
321345
("model/?([^/]*)/license", AquaModelLicenseHandler),
346+
("model/?([^/]*)/chat_template", AquaModelChatTemplateHandler),
322347
("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
323348
]

0 commit comments

Comments
 (0)