Skip to content

Commit 8997594

Browse files
Addressing review comments
1 parent 74601c0 commit 8997594

File tree

5 files changed

+72
-57
lines changed

5 files changed

+72
-57
lines changed

ads/aqua/app.py

Lines changed: 12 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
import traceback
88
from dataclasses import fields
9-
from typing import Dict, Union
9+
from typing import Dict, Optional, Union
1010

1111
import oci
1212
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
@@ -268,7 +268,12 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
268268
logger.info(f"Artifact not found in model {model_id}.")
269269
return False
270270

271-
def get_config(self, model_id: str, config_file_name: str) -> Dict:
271+
def get_config(
272+
self,
273+
model_id: str,
274+
config_file_name: str,
275+
config_folder: Optional[str] = "config",
276+
) -> Dict:
272277
"""Gets the config for the given Aqua model.
273278
274279
Parameters
@@ -277,6 +282,9 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
277282
The OCID of the Aqua model.
278283
config_file_name: str
279284
name of the config file
285+
config_folder: Optional[str]
286+
subfolder path in which config_file_name is searched for
287+
default value: config
280288
281289
Returns
282290
-------
@@ -306,11 +314,11 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
306314
)
307315
base_model = self.ds_client.get_model(base_model_ocid).data
308316
artifact_path = get_artifact_path(base_model.custom_metadata_list)
309-
config_path = f"{os.path.dirname(artifact_path)}/config/"
317+
config_path = f"{os.path.dirname(artifact_path)}/{config_folder}/"
310318
else:
311319
logger.info(f"Loading {config_file_name} for model {oci_model.id}...")
312320
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
313-
config_path = f"{artifact_path.rstrip('/')}/config/"
321+
config_path = f"{artifact_path.rstrip('/')}/{config_folder}/"
314322

315323
if not artifact_path:
316324
logger.debug(
@@ -340,55 +348,6 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
340348

341349
return config
342350

343-
def get_chat_template(self, model_id):
344-
"""Gets the default chat template for the given Aqua model.
345-
346-
Parameters
347-
----------
348-
model_id: str
349-
The OCID of the Aqua model.
350-
351-
Returns
352-
-------
353-
str:
354-
Chat template string.
355-
"""
356-
chat_template = ""
357-
oci_model = self.ds_client.get_model(model_id).data
358-
oci_aqua = (
359-
(
360-
Tags.AQUA_TAG in oci_model.freeform_tags
361-
or Tags.AQUA_TAG.lower() in oci_model.freeform_tags
362-
)
363-
if oci_model.freeform_tags
364-
else False
365-
)
366-
367-
if not oci_aqua:
368-
raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
369-
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
370-
if not artifact_path:
371-
logger.debug(
372-
f"Failed to get artifact path from custom metadata for the model: {model_id}"
373-
)
374-
return chat_template
375-
376-
try:
377-
tokenizer_path = f"{os.path.dirname(artifact_path)}/artifact"
378-
chat_template = load_config(
379-
file_path=tokenizer_path, config_file_name="tokenizer_config.json"
380-
)
381-
except Exception:
382-
logger.error(
383-
f"Error reading tokenizer_config.json file for the model: {model_id}"
384-
)
385-
386-
if not chat_template:
387-
logger.error(
388-
f"No default chat template is available for the model: {model_id}."
389-
)
390-
return {"chat_template": chat_template.get("chat_template")}
391-
392351
@property
393352
def telemetry(self):
394353
if not self._telemetry:

ads/aqua/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
README = "README.md"
1111
LICENSE_TXT = "config/LICENSE.txt"
1212
DEPLOYMENT_CONFIG = "deployment_config.json"
13+
AQUA_MODEL_TOKENIZER_CONFIG = "tokenizer_config.json"
1314
COMPARTMENT_MAPPING_KEY = "service-model-compartment"
1415
CONTAINER_INDEX = "container_index.json"
1516
EVALUATION_REPORT_JSON = "report.json"

ads/aqua/extension/model_handler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
317317
)
318318

319319

320-
class AquaModelChatTemplateHandler(AquaAPIhandler):
320+
class AquaModelTokenizerConfigHandler(AquaAPIhandler):
321321
def get(self, model_id):
322322
url_parse = urlparse(self.request.path)
323323
paths = url_parse.path.strip("/")
@@ -327,14 +327,14 @@ def get(self, model_id):
327327
and is_valid_ocid(path_list[2])
328328
and path_list[3] == "chat_template"
329329
):
330-
return self.finish(AquaModelApp().get_chat_template(model_id))
330+
return self.finish(AquaModelApp().get_hf_tokenizer_config(model_id))
331331
else:
332332
raise HTTPError(400, f"The request {self.request.path} is invalid.")
333333

334334

335335
__handlers__ = [
336336
("model/?([^/]*)", AquaModelHandler),
337337
("model/?([^/]*)/license", AquaModelLicenseHandler),
338-
("model/?([^/]*)/chat_template", AquaModelChatTemplateHandler),
338+
("model/?([^/]*)/chat_template", AquaModelTokenizerConfigHandler),
339339
("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
340340
]

ads/aqua/model/model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
4545
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
4646
AQUA_MODEL_ARTIFACT_FILE,
47+
AQUA_MODEL_TOKENIZER_CONFIG,
4748
AQUA_MODEL_TYPE_CUSTOM,
4849
HF_METADATA_FOLDER,
4950
LICENSE_TXT,
@@ -568,6 +569,24 @@ def _build_ft_metrics(
568569
training_final,
569570
]
570571

572+
def get_hf_tokenizer_config(self, model_id):
573+
"""Gets the Hugging Face tokenizer config (tokenizer_config.json) for the given Aqua model.
574+
575+
Parameters
576+
----------
577+
model_id: str
578+
The OCID of the Aqua model.
579+
580+
Returns
581+
-------
582+
Dict:
583+
Contents of the model's tokenizer_config.json as returned by get_config; may be empty if the config is not available.
584+
"""
585+
config = self.get_config(model_id, AQUA_MODEL_TOKENIZER_CONFIG, "artifact")
586+
if not config:
587+
logger.debug(f"Tokenizer config for model: {model_id} is not available.")
588+
return config
589+
571590
@staticmethod
572591
def to_aqua_model(
573592
model: Union[

tests/unitary/with_extras/aqua/test_model_handler.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010
from huggingface_hub.hf_api import HfApi, ModelInfo
1111
from huggingface_hub.utils import GatedRepoError
12-
from notebook.base.handlers import IPythonHandler
12+
from notebook.base.handlers import IPythonHandler, HTTPError
1313
from parameterized import parameterized
1414

1515
from ads.aqua.common.errors import AquaRuntimeError
@@ -18,6 +18,7 @@
1818
AquaHuggingFaceHandler,
1919
AquaModelHandler,
2020
AquaModelLicenseHandler,
21+
AquaModelTokenizerConfigHandler,
2122
)
2223
from ads.aqua.model import AquaModelApp
2324
from ads.aqua.model.entities import AquaModel, AquaModelSummary, HFModelSummary
@@ -250,6 +251,41 @@ def test_get(self, mock_load_license):
250251
mock_load_license.assert_called_with("test_model_id")
251252

252253

254+
class ModelTokenizerConfigHandlerTestCase(TestCase):
255+
@patch.object(IPythonHandler, "__init__")
256+
def setUp(self, ipython_init_mock) -> None:
257+
ipython_init_mock.return_value = None
258+
self.model_tokenizer_config_handler = AquaModelTokenizerConfigHandler(
259+
MagicMock(), MagicMock()
260+
)
261+
self.model_tokenizer_config_handler.finish = MagicMock()
262+
self.model_tokenizer_config_handler.request = MagicMock()
263+
264+
@patch.object(AquaModelApp, "get_hf_tokenizer_config")
265+
@patch("ads.aqua.extension.model_handler.urlparse")
266+
def test_get(self, mock_urlparse, mock_get_hf_tokenizer_config):
267+
request_path = MagicMock(path="aqua/model/ocid1.xx./chat_template")
268+
mock_urlparse.return_value = request_path
269+
self.model_tokenizer_config_handler.get(model_id="test_model_id")
270+
self.model_tokenizer_config_handler.finish.assert_called_with(
271+
mock_get_hf_tokenizer_config.return_value
272+
)
273+
mock_get_hf_tokenizer_config.assert_called_with("test_model_id")
274+
275+
@patch.object(AquaModelApp, "get_hf_tokenizer_config")
276+
@patch("ads.aqua.extension.model_handler.urlparse")
277+
def test_get_invalid_path(self, mock_urlparse, mock_get_hf_tokenizer_config):
278+
"""Test invalid request path should raise HTTPError(400)"""
279+
request_path = MagicMock(path="/invalid/path")
280+
mock_urlparse.return_value = request_path
281+
282+
with self.assertRaises(HTTPError) as context:
283+
self.model_tokenizer_config_handler.get(model_id="test_model_id")
284+
self.assertEqual(context.exception.status_code, 400)
285+
self.model_tokenizer_config_handler.finish.assert_not_called()
286+
mock_get_hf_tokenizer_config.assert_not_called()
287+
288+
253289
class TestAquaHuggingFaceHandler:
254290
def setup_method(self):
255291
with patch.object(IPythonHandler, "__init__"):

0 commit comments

Comments
 (0)