Skip to content

Commit 0db36f2

Browse files
Adding API to fetch tokenizer config for model (#1052)
2 parents 7e9fc1b + be96c70 commit 0db36f2

File tree

8 files changed

+114
-11
lines changed

8 files changed

+114
-11
lines changed

ads/aqua/app.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
import os
77
import traceback
88
from dataclasses import fields
9-
from typing import Dict, Union
9+
from typing import Dict, Optional, Union
1010

1111
import oci
1212
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
1313

1414
from ads import set_auth
1515
from ads.aqua import logger
16-
from ads.aqua.common.enums import Tags
16+
from ads.aqua.common.enums import ConfigFolder, Tags
1717
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1818
from ads.aqua.common.utils import (
1919
_is_valid_mvs,
@@ -268,7 +268,12 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
268268
logger.info(f"Artifact not found in model {model_id}.")
269269
return False
270270

271-
def get_config(self, model_id: str, config_file_name: str) -> Dict:
271+
def get_config(
272+
self,
273+
model_id: str,
274+
config_file_name: str,
275+
config_folder: Optional[str] = ConfigFolder.CONFIG,
276+
) -> Dict:
272277
"""Gets the config for the given Aqua model.
273278
274279
Parameters
@@ -277,12 +282,17 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
277282
The OCID of the Aqua model.
278283
config_file_name: str
279284
name of the config file
285+
config_folder: (str, optional):
286+
subfolder path where config_file_name needs to be searched
287+
Defaults to `ConfigFolder.CONFIG`.
288+
When searching inside model artifact directory , the value is ConfigFolder.ARTIFACT`
280289
281290
Returns
282291
-------
283292
Dict:
284293
A dict of allowed configs.
285294
"""
295+
config_folder = config_folder or ConfigFolder.CONFIG
286296
oci_model = self.ds_client.get_model(model_id).data
287297
oci_aqua = (
288298
(
@@ -304,22 +314,25 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
304314
f"Base model found for the model: {oci_model.id}. "
305315
f"Loading {config_file_name} for base model {base_model_ocid}."
306316
)
307-
base_model = self.ds_client.get_model(base_model_ocid).data
308-
artifact_path = get_artifact_path(base_model.custom_metadata_list)
317+
if config_folder == ConfigFolder.ARTIFACT:
318+
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
319+
else:
320+
base_model = self.ds_client.get_model(base_model_ocid).data
321+
artifact_path = get_artifact_path(base_model.custom_metadata_list)
309322
else:
310323
logger.info(f"Loading {config_file_name} for model {oci_model.id}...")
311324
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
312-
313325
if not artifact_path:
314326
logger.debug(
315327
f"Failed to get artifact path from custom metadata for the model: {model_id}"
316328
)
317329
return config
318330

319-
config_path = f"{os.path.dirname(artifact_path)}/config/"
331+
config_path = os.path.join(os.path.dirname(artifact_path), config_folder)
320332
if not is_path_exists(config_path):
321-
config_path = f"{artifact_path.rstrip('/')}/config/"
322-
333+
config_path = os.path.join(artifact_path.rstrip("/"), config_folder)
334+
if not is_path_exists(config_path):
335+
config_path = f"{artifact_path.rstrip('/')}/"
323336
config_file_path = f"{config_path}{config_file_name}"
324337
if is_path_exists(config_file_path):
325338
try:

ads/aqua/common/enums.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,8 @@ class TextEmbeddingInferenceContainerParams(ExtendedEnum):
9292

9393
MODEL_ID = "model-id"
9494
PORT = "port"
95+
96+
97+
class ConfigFolder(ExtendedEnum):
98+
CONFIG = "config"
99+
ARTIFACT = "artifact"

ads/aqua/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
README = "README.md"
1111
LICENSE_TXT = "config/LICENSE.txt"
1212
DEPLOYMENT_CONFIG = "deployment_config.json"
13+
AQUA_MODEL_TOKENIZER_CONFIG = "tokenizer_config.json"
1314
COMPARTMENT_MAPPING_KEY = "service-model-compartment"
1415
CONTAINER_INDEX = "container_index.json"
1516
EVALUATION_REPORT_JSON = "report.json"

ads/aqua/extension/model_handler.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1515
from ads.aqua.common.utils import (
1616
get_hf_model_info,
17+
is_valid_ocid,
1718
list_hf_models,
1819
)
1920
from ads.aqua.extension.base_handler import AquaAPIhandler
@@ -316,8 +317,30 @@ def post(self, *args, **kwargs): # noqa: ARG002
316317
)
317318

318319

320+
class AquaModelTokenizerConfigHandler(AquaAPIhandler):
321+
def get(self, model_id):
322+
"""
323+
Handles requests for retrieving the Hugging Face tokenizer configuration of a specified model.
324+
Expected request format: GET /aqua/models/<model-ocid>/tokenizer
325+
326+
"""
327+
328+
path_list = urlparse(self.request.path).path.strip("/").split("/")
329+
# Path should be /aqua/models/ocid1.iad.ahdxxx/tokenizer
330+
# path_list=['aqua','models','<model-ocid>','tokenizer']
331+
if (
332+
len(path_list) == 4
333+
and is_valid_ocid(path_list[2])
334+
and path_list[3] == "tokenizer"
335+
):
336+
return self.finish(AquaModelApp().get_hf_tokenizer_config(model_id))
337+
338+
raise HTTPError(400, f"The request {self.request.path} is invalid.")
339+
340+
319341
__handlers__ = [
320342
("model/?([^/]*)", AquaModelHandler),
321343
("model/?([^/]*)/license", AquaModelLicenseHandler),
344+
("model/?([^/]*)/tokenizer", AquaModelTokenizerConfigHandler),
322345
("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
323346
]

ads/aqua/model/model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
1616
from ads.aqua.app import AquaApp
1717
from ads.aqua.common.enums import (
18+
ConfigFolder,
1819
CustomInferenceContainerTypeFamily,
1920
FineTuningContainerTypeFamily,
2021
InferenceContainerTypeFamily,
@@ -44,6 +45,7 @@
4445
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME,
4546
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE,
4647
AQUA_MODEL_ARTIFACT_FILE,
48+
AQUA_MODEL_TOKENIZER_CONFIG,
4749
AQUA_MODEL_TYPE_CUSTOM,
4850
HF_METADATA_FOLDER,
4951
LICENSE_TXT,
@@ -568,6 +570,26 @@ def _build_ft_metrics(
568570
training_final,
569571
]
570572

573+
def get_hf_tokenizer_config(self, model_id):
574+
"""Gets the default chat template for the given Aqua model.
575+
576+
Parameters
577+
----------
578+
model_id: str
579+
The OCID of the Aqua model.
580+
581+
Returns
582+
-------
583+
str:
584+
Chat template string.
585+
"""
586+
config = self.get_config(
587+
model_id, AQUA_MODEL_TOKENIZER_CONFIG, ConfigFolder.ARTIFACT
588+
)
589+
if not config:
590+
logger.debug(f"Tokenizer config for model: {model_id} is not available.")
591+
return config
592+
571593
@staticmethod
572594
def to_aqua_model(
573595
model: Union[

ads/aqua/modeldeployment/entities.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class AquaDeployment(DataClassSerializable):
4141
id: str = None
4242
display_name: str = None
4343
aqua_service_model: bool = None
44+
model_id: str = None
4445
aqua_model_name: str = None
4546
state: str = None
4647
description: str = None
@@ -97,7 +98,7 @@ def from_oci_model_deployment(
9798
else None
9899
),
99100
)
100-
101+
model_id = oci_model_deployment._model_deployment_configuration_details.model_configuration_details.model_id
101102
tags = {}
102103
tags.update(oci_model_deployment.freeform_tags or UNKNOWN_DICT)
103104
tags.update(oci_model_deployment.defined_tags or UNKNOWN_DICT)
@@ -110,6 +111,7 @@ def from_oci_model_deployment(
110111

111112
return AquaDeployment(
112113
id=oci_model_deployment.id,
114+
model_id=model_id,
113115
display_name=oci_model_deployment.display_name,
114116
aqua_service_model=aqua_service_model_tag is not None,
115117
aqua_model_name=aqua_model_name,

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ class TestDataset:
254254
"created_by": "ocid1.user.oc1..<OCID>",
255255
"endpoint": MODEL_DEPLOYMENT_URL,
256256
"private_endpoint_id": null,
257+
"model_id": "ocid1.datasciencemodel.oc1.<region>.<OCID>",
257258
"environment_variables": {
258259
"BASE_MODEL": "service_models/model-name/artifact",
259260
"MODEL_DEPLOY_ENABLE_STREAMING": "true",

tests/unitary/with_extras/aqua/test_model_handler.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010
from huggingface_hub.hf_api import HfApi, ModelInfo
1111
from huggingface_hub.utils import GatedRepoError
12-
from notebook.base.handlers import IPythonHandler
12+
from notebook.base.handlers import IPythonHandler, HTTPError
1313
from parameterized import parameterized
1414

1515
from ads.aqua.common.errors import AquaRuntimeError
@@ -18,6 +18,7 @@
1818
AquaHuggingFaceHandler,
1919
AquaModelHandler,
2020
AquaModelLicenseHandler,
21+
AquaModelTokenizerConfigHandler,
2122
)
2223
from ads.aqua.model import AquaModelApp
2324
from ads.aqua.model.entities import AquaModel, AquaModelSummary, HFModelSummary
@@ -250,6 +251,41 @@ def test_get(self, mock_load_license):
250251
mock_load_license.assert_called_with("test_model_id")
251252

252253

254+
class ModelTokenizerConfigHandlerTestCase(TestCase):
255+
@patch.object(IPythonHandler, "__init__")
256+
def setUp(self, ipython_init_mock) -> None:
257+
ipython_init_mock.return_value = None
258+
self.model_tokenizer_config_handler = AquaModelTokenizerConfigHandler(
259+
MagicMock(), MagicMock()
260+
)
261+
self.model_tokenizer_config_handler.finish = MagicMock()
262+
self.model_tokenizer_config_handler.request = MagicMock()
263+
264+
@patch.object(AquaModelApp, "get_hf_tokenizer_config")
265+
@patch("ads.aqua.extension.model_handler.urlparse")
266+
def test_get(self, mock_urlparse, mock_get_hf_tokenizer_config):
267+
request_path = MagicMock(path="aqua/model/ocid1.xx./tokenizer")
268+
mock_urlparse.return_value = request_path
269+
self.model_tokenizer_config_handler.get(model_id="test_model_id")
270+
self.model_tokenizer_config_handler.finish.assert_called_with(
271+
mock_get_hf_tokenizer_config.return_value
272+
)
273+
mock_get_hf_tokenizer_config.assert_called_with("test_model_id")
274+
275+
@patch.object(AquaModelApp, "get_hf_tokenizer_config")
276+
@patch("ads.aqua.extension.model_handler.urlparse")
277+
def test_get_invalid_path(self, mock_urlparse, mock_get_hf_tokenizer_config):
278+
"""Test invalid request path should raise HTTPError(400)"""
279+
request_path = MagicMock(path="/invalid/path")
280+
mock_urlparse.return_value = request_path
281+
282+
with self.assertRaises(HTTPError) as context:
283+
self.model_tokenizer_config_handler.get(model_id="test_model_id")
284+
self.assertEqual(context.exception.status_code, 400)
285+
self.model_tokenizer_config_handler.finish.assert_not_called()
286+
mock_get_hf_tokenizer_config.assert_not_called()
287+
288+
253289
class TestAquaHuggingFaceHandler:
254290
def setup_method(self):
255291
with patch.object(IPythonHandler, "__init__"):

0 commit comments

Comments
 (0)