Skip to content

Commit 3cc7e8a

Browse files
Adding chat template api
1 parent 30592c8 commit 3cc7e8a

File tree

3 files changed

+12
-15
lines changed

3 files changed

+12
-15
lines changed

ads/aqua/app.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
get_artifact_path,
2020
is_valid_ocid,
2121
load_config,
22-
read_file,
2322
)
2423
from ads.aqua.constants import UNKNOWN
2524
from ads.common import oci_client as oc
@@ -363,18 +362,20 @@ def get_chat_template(self, model_id):
363362
return chat_template
364363

365364
try:
366-
tokenizer_path = f"{os.path.dirname(artifact_path)}/tokenizer_config.json"
367-
chat_template = read_file(tokenizer_path)
365+
tokenizer_path = f"{os.path.dirname(artifact_path)}/artifact"
366+
chat_template = load_config(
367+
file_path=tokenizer_path, config_file_name="tokenizer_config.json"
368+
)
368369
except Exception:
369-
pass
370+
logger.error(
371+
f"Error reading tokenizer_config.json file for the model: {model_id}"
372+
)
370373

371374
if not chat_template:
372375
logger.error(
373376
f"No default chat template is available for the model: {model_id}."
374377
)
375-
return chat_template
376-
377-
return chat_template
378+
return {"chat_template": chat_template.get("chat_template")}
378379

379380
@property
380381
def telemetry(self):

ads/aqua/common/utils.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def get_artifact_path(custom_metadata_list: List) -> str:
228228
return UNKNOWN
229229

230230

231-
def read_file(file_path: str, **kwargs) -> str:
231+
def read_file(file_path: str, **kwargs) -> Union[str, dict]:
232232
try:
233233
with fsspec.open(file_path, "r", **kwargs.get("auth", {})) as f:
234234
return f.read()
@@ -239,10 +239,7 @@ def read_file(file_path: str, **kwargs) -> str:
239239

240240
@threaded()
241241
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
242-
if config_file_name:
243-
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
244-
else:
245-
artifact_path = f"{file_path.rstrip('/')}"
242+
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
246243
signer = default_signer() if artifact_path.startswith("oci://") else {}
247244
config = json.loads(
248245
read_file(file_path=artifact_path, auth=signer, **kwargs) or UNKNOWN_JSON_STR

ads/aqua/extension/model_handler.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1515
from ads.aqua.common.utils import (
1616
get_hf_model_info,
17+
is_valid_ocid,
1718
list_hf_models,
1819
)
1920
from ads.aqua.extension.base_handler import AquaAPIhandler
@@ -35,7 +36,6 @@ def get(
3536
url_parse = urlparse(self.request.path)
3637
paths = url_parse.path.strip("/")
3738
path_list = paths.split("/")
38-
print(path_list)
3939
if paths.startswith("aqua/model/files"):
4040
os_path = self.get_argument("os_path", None)
4141
model_name = self.get_argument("model_name", None)
@@ -329,10 +329,9 @@ def get(self, model_id):
329329
url_parse = urlparse(self.request.path)
330330
paths = url_parse.path.strip("/")
331331
path_list = paths.split("/")
332-
print(path_list)
333332
if (
334333
len(path_list) == 4
335-
and path_list[2].startswith("ocid1.datasciencemodel")
334+
and is_valid_ocid(path_list[2])
336335
and path_list[3] == "chat_template"
337336
):
338337
return self.finish(AquaModelApp().get_chat_template(model_id))

0 commit comments

Comments
 (0)